From f0a171923cc8240177302f3dccb6177a2c646ab3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 11 May 2023 18:37:43 +0200 Subject: Update --- training/functional.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index eae5681..49c21c7 100644 --- a/training/functional.py +++ b/training/functional.py @@ -27,6 +27,7 @@ from models.convnext.discriminator import ConvNeXtDiscriminator from training.util import AverageMeter from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler from util.slerp import slerp +from util.noise import perlin_noise def const(result=None): @@ -350,28 +351,33 @@ def loss_step( device=latents.device, generator=generator ) + applied_noise = noise if offset_noise_strength != 0: - offset_noise = torch.randn( - (latents.shape[0], latents.shape[1], 1, 1), + applied_noise = applied_noise + offset_noise_strength * perlin_noise( + latents.shape, + res=1, + octaves=4, dtype=latents.dtype, device=latents.device, generator=generator - ).expand(noise.shape) - noise = noise + offset_noise_strength * offset_noise + ) if input_pertubation != 0: - new_noise = noise + input_pertubation * torch.randn_like(noise) + applied_noise = applied_noise + input_pertubation * torch.randn( + latents.shape, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) # Sample a random timestep for each image timesteps, weights = schedule_sampler.sample(bsz, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - if input_pertubation != 0: - noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) - else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps) noisy_latents = noisy_latents.to(dtype=unet.dtype) # Get the text embedding for conditioning -- cgit v1.2.3-54-g00ecf