diff options
Diffstat (limited to 'models/convnext')
| -rw-r--r-- | models/convnext/discriminator.py | 11 | 
1 files changed, 8 insertions, 3 deletions
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py index 571b915..5798bcf 100644 --- a/models/convnext/discriminator.py +++ b/models/convnext/discriminator.py  | |||
| @@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |||
| 5 | from torch.nn import functional as F | 5 | from torch.nn import functional as F | 
| 6 | 6 | ||
| 7 | 7 | ||
| 8 | class ConvNeXtDiscriminator(): | 8 | class ConvNeXtDiscriminator: | 
| 9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: | 9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: | 
| 10 | self.net = model | 10 | self.net = model | 
| 11 | 11 | ||
| @@ -22,8 +22,13 @@ class ConvNeXtDiscriminator(): | |||
| 22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) | 22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) | 
| 23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) | 23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) | 
| 24 | 24 | ||
| 25 | img = ((img + 1.) / 2.).sub(img_mean).div(img_std) | 25 | img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) | 
| 26 | 26 | ||
| 27 | img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) | 27 | img = F.interpolate( | 
| 28 | img, | ||
| 29 | size=(self.input_size, self.input_size), | ||
| 30 | mode="bicubic", | ||
| 31 | align_corners=True, | ||
| 32 | ) | ||
| 28 | pred = self.net(img) | 33 | pred = self.net(img) | 
| 29 | return pred | 34 | return pred | 
