From 51394430b6b142eb21641f251b5fe32cdf802ab8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 1 Mar 2023 14:10:23 +0100 Subject: Changed low freq noise --- training/functional.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 990c4cd..8ea40bb 100644 --- a/training/functional.py +++ b/training/functional.py @@ -268,35 +268,22 @@ def loss_step( generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None # Sample noise that we'll add to the latents - if low_freq_noise == 0: - noise = torch.randn( - latents.shape, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - else: - noise = (1 - low_freq_noise) * torch.randn( - latents.shape, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) + low_freq_noise * torch.randn( + noise = torch.randn( + latents.shape, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) + + if low_freq_noise != 0: + noise *= 1 - low_freq_noise + low_freq_noise * torch.randn( latents.shape[0], latents.shape[1], 1, 1, dtype=latents.dtype, layout=latents.layout, device=latents.device, generator=generator ) - # noise += low_freq_noise * torch.randn( - # bsz, 1, 1, 1, - # dtype=latents.dtype, - # layout=latents.layout, - # device=latents.device, - # generator=generator - # ) # Sample a random timestep for each image timesteps = torch.randint( -- cgit v1.2.3-54-g00ecf