diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py index eae5681..49c21c7 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -27,6 +27,7 @@ from models.convnext.discriminator import ConvNeXtDiscriminator | |||
| 27 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
| 28 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler | 28 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler |
| 29 | from util.slerp import slerp | 29 | from util.slerp import slerp |
| 30 | from util.noise import perlin_noise | ||
| 30 | 31 | ||
| 31 | 32 | ||
| 32 | def const(result=None): | 33 | def const(result=None): |
| @@ -350,28 +351,33 @@ def loss_step( | |||
| 350 | device=latents.device, | 351 | device=latents.device, |
| 351 | generator=generator | 352 | generator=generator |
| 352 | ) | 353 | ) |
| 354 | applied_noise = noise | ||
| 353 | 355 | ||
| 354 | if offset_noise_strength != 0: | 356 | if offset_noise_strength != 0: |
| 355 | offset_noise = torch.randn( | 357 | applied_noise = applied_noise + offset_noise_strength * perlin_noise( |
| 356 | (latents.shape[0], latents.shape[1], 1, 1), | 358 | latents.shape, |
| 359 | res=1, | ||
| 360 | octaves=4, | ||
| 357 | dtype=latents.dtype, | 361 | dtype=latents.dtype, |
| 358 | device=latents.device, | 362 | device=latents.device, |
| 359 | generator=generator | 363 | generator=generator |
| 360 | ).expand(noise.shape) | 364 | ) |
| 361 | noise = noise + offset_noise_strength * offset_noise | ||
| 362 | 365 | ||
| 363 | if input_pertubation != 0: | 366 | if input_pertubation != 0: |
| 364 | new_noise = noise + input_pertubation * torch.randn_like(noise) | 367 | applied_noise = applied_noise + input_pertubation * torch.randn( |
| 368 | latents.shape, | ||
| 369 | dtype=latents.dtype, | ||
| 370 | layout=latents.layout, | ||
| 371 | device=latents.device, | ||
| 372 | generator=generator | ||
| 373 | ) | ||
| 365 | 374 | ||
| 366 | # Sample a random timestep for each image | 375 | # Sample a random timestep for each image |
| 367 | timesteps, weights = schedule_sampler.sample(bsz, latents.device) | 376 | timesteps, weights = schedule_sampler.sample(bsz, latents.device) |
| 368 | 377 | ||
| 369 | # Add noise to the latents according to the noise magnitude at each timestep | 378 | # Add noise to the latents according to the noise magnitude at each timestep |
| 370 | # (this is the forward diffusion process) | 379 | # (this is the forward diffusion process) |
| 371 | if input_pertubation != 0: | 380 | noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps) |
| 372 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) | ||
| 373 | else: | ||
| 374 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 375 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 381 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
| 376 | 382 | ||
| 377 | # Get the text embedding for conditioning | 383 | # Get the text embedding for conditioning |
