From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- models/convnext/discriminator.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'models/convnext') diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py index 571b915..5798bcf 100644 --- a/models/convnext/discriminator.py +++ b/models/convnext/discriminator.py @@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.nn import functional as F -class ConvNeXtDiscriminator(): +class ConvNeXtDiscriminator: def __init__(self, model: ConvNeXt, input_size: int) -> None: self.net = model @@ -22,8 +22,13 @@ class ConvNeXtDiscriminator(): img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) img_std = self.img_std.to(device=img.device, dtype=img.dtype) - img = ((img + 1.) / 2.).sub(img_mean).div(img_std) + img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) - img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) + img = F.interpolate( + img, + size=(self.input_size, self.input_size), + mode="bicubic", + align_corners=True, + ) pred = self.net(img) return pred -- cgit v1.2.3-54-g00ecf