diff options
author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /models/convnext | |
parent | Fix LoRA training with DAdan (diff) | |
download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip |
Update
Diffstat (limited to 'models/convnext')
-rw-r--r-- | models/convnext/discriminator.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py index 571b915..5798bcf 100644 --- a/models/convnext/discriminator.py +++ b/models/convnext/discriminator.py | |||
@@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |||
5 | from torch.nn import functional as F | 5 | from torch.nn import functional as F |
6 | 6 | ||
7 | 7 | ||
8 | class ConvNeXtDiscriminator(): | 8 | class ConvNeXtDiscriminator: |
9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: | 9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: |
10 | self.net = model | 10 | self.net = model |
11 | 11 | ||
@@ -22,8 +22,13 @@ class ConvNeXtDiscriminator(): | |||
22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) | 22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) |
23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) | 23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) |
24 | 24 | ||
25 | img = ((img + 1.) / 2.).sub(img_mean).div(img_std) | 25 | img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) |
26 | 26 | ||
27 | img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) | 27 | img = F.interpolate( |
28 | img, | ||
29 | size=(self.input_size, self.input_size), | ||
30 | mode="bicubic", | ||
31 | align_corners=True, | ||
32 | ) | ||
28 | pred = self.net(img) | 33 | pred = self.net(img) |
29 | return pred | 34 | return pred |