diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 40 |
1 files changed, 4 insertions, 36 deletions
diff --git a/training/functional.py b/training/functional.py index 06848cb..c30d1c0 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -321,45 +321,13 @@ def loss_step( | |||
321 | ) | 321 | ) |
322 | 322 | ||
323 | if offset_noise_strength != 0: | 323 | if offset_noise_strength != 0: |
324 | solid_image = partial( | 324 | offset_noise = torch.randn( |
325 | make_solid_image, | 325 | (latents.shape[0], latents.shape[1], 1, 1), |
326 | shape=images.shape[1:], | ||
327 | vae=vae, | ||
328 | dtype=latents.dtype, | 326 | dtype=latents.dtype, |
329 | device=latents.device, | 327 | device=latents.device, |
330 | generator=generator | 328 | generator=generator |
331 | ) | 329 | ).expand(noise.shape) |
332 | 330 | noise += offset_noise_strength * offset_noise | |
333 | white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" | ||
334 | black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}" | ||
335 | |||
336 | if white_cache_key not in cache: | ||
337 | img_white = solid_image(1) | ||
338 | cache[white_cache_key] = img_white | ||
339 | else: | ||
340 | img_white = cache[white_cache_key] | ||
341 | |||
342 | if black_cache_key not in cache: | ||
343 | img_black = solid_image(0) | ||
344 | cache[black_cache_key] = img_black | ||
345 | else: | ||
346 | img_black = cache[black_cache_key] | ||
347 | |||
348 | offset_strength = torch.rand( | ||
349 | (bsz, 1, 1, 1), | ||
350 | dtype=latents.dtype, | ||
351 | layout=latents.layout, | ||
352 | device=latents.device, | ||
353 | generator=generator | ||
354 | ) | ||
355 | offset_strength = offset_noise_strength * (offset_strength * 2 - 1) | ||
356 | offset_images = torch.where( | ||
357 | offset_strength >= 0, | ||
358 | img_white.expand(noise.shape), | ||
359 | img_black.expand(noise.shape) | ||
360 | ) | ||
361 | offset_strength = offset_strength.abs().expand(noise.shape) | ||
362 | noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2)) | ||
363 | 331 | ||
364 | # Sample a random timestep for each image | 332 | # Sample a random timestep for each image |
365 | timesteps = torch.randint( | 333 | timesteps = torch.randint( |