From ab24e5cbd8283ad4ced486e1369484ebf9e3962d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 6 Apr 2023 16:06:04 +0200 Subject: Update --- training/functional.py | 40 ++++------------------------------------ 1 file changed, 4 insertions(+), 36 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 06848cb..c30d1c0 100644 --- a/training/functional.py +++ b/training/functional.py @@ -321,45 +321,13 @@ def loss_step( ) if offset_noise_strength != 0: - solid_image = partial( - make_solid_image, - shape=images.shape[1:], - vae=vae, + offset_noise = torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), dtype=latents.dtype, device=latents.device, generator=generator - ) - - white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" - black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}" - - if white_cache_key not in cache: - img_white = solid_image(1) - cache[white_cache_key] = img_white - else: - img_white = cache[white_cache_key] - - if black_cache_key not in cache: - img_black = solid_image(0) - cache[black_cache_key] = img_black - else: - img_black = cache[black_cache_key] - - offset_strength = torch.rand( - (bsz, 1, 1, 1), - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - offset_strength = offset_noise_strength * (offset_strength * 2 - 1) - offset_images = torch.where( - offset_strength >= 0, - img_white.expand(noise.shape), - img_black.expand(noise.shape) - ) - offset_strength = offset_strength.abs().expand(noise.shape) - noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2)) + ).expand(noise.shape) + noise += offset_noise_strength * offset_noise # Sample a random timestep for each image timesteps = torch.randint( -- cgit v1.2.3-70-g09d2