summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py
index eae5681..49c21c7 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -27,6 +27,7 @@ from models.convnext.discriminator import ConvNeXtDiscriminator
27from training.util import AverageMeter 27from training.util import AverageMeter
28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler 28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
29from util.slerp import slerp 29from util.slerp import slerp
30from util.noise import perlin_noise
30 31
31 32
32def const(result=None): 33def const(result=None):
@@ -350,28 +351,33 @@ def loss_step(
350 device=latents.device, 351 device=latents.device,
351 generator=generator 352 generator=generator
352 ) 353 )
354 applied_noise = noise
353 355
354 if offset_noise_strength != 0: 356 if offset_noise_strength != 0:
355 offset_noise = torch.randn( 357 applied_noise = applied_noise + offset_noise_strength * perlin_noise(
356 (latents.shape[0], latents.shape[1], 1, 1), 358 latents.shape,
359 res=1,
360 octaves=4,
357 dtype=latents.dtype, 361 dtype=latents.dtype,
358 device=latents.device, 362 device=latents.device,
359 generator=generator 363 generator=generator
360 ).expand(noise.shape) 364 )
361 noise = noise + offset_noise_strength * offset_noise
362 365
363 if input_pertubation != 0: 366 if input_pertubation != 0:
364 new_noise = noise + input_pertubation * torch.randn_like(noise) 367 applied_noise = applied_noise + input_pertubation * torch.randn(
368 latents.shape,
369 dtype=latents.dtype,
370 layout=latents.layout,
371 device=latents.device,
372 generator=generator
373 )
365 374
366 # Sample a random timestep for each image 375 # Sample a random timestep for each image
367 timesteps, weights = schedule_sampler.sample(bsz, latents.device) 376 timesteps, weights = schedule_sampler.sample(bsz, latents.device)
368 377
369 # Add noise to the latents according to the noise magnitude at each timestep 378 # Add noise to the latents according to the noise magnitude at each timestep
370 # (this is the forward diffusion process) 379 # (this is the forward diffusion process)
371 if input_pertubation != 0: 380 noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)
372 noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
373 else:
374 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
375 noisy_latents = noisy_latents.to(dtype=unet.dtype) 381 noisy_latents = noisy_latents.to(dtype=unet.dtype)
376 382
377 # Get the text embedding for conditioning 383 # Get the text embedding for conditioning