path: root/training/functional.py
author     Volpeon <git@volpeon.ink>  2023-04-06 16:06:04 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-06 16:06:04 +0200
commit     ab24e5cbd8283ad4ced486e1369484ebf9e3962d (patch)
tree       7d47b7cd38e7313071ad4a671b14f8a23dcd7389 /training/functional.py
parent     MinSNR code from diffusers (diff)
download   textual-inversion-diff-ab24e5cbd8283ad4ced486e1369484ebf9e3962d.tar.gz
           textual-inversion-diff-ab24e5cbd8283ad4ced486e1369484ebf9e3962d.tar.bz2
           textual-inversion-diff-ab24e5cbd8283ad4ced486e1369484ebf9e3962d.zip
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  40
1 file changed, 4 insertions(+), 36 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 06848cb..c30d1c0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -321,45 +321,13 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        solid_image = partial(
-            make_solid_image,
-            shape=images.shape[1:],
-            vae=vae,
+        offset_noise = torch.randn(
+            (latents.shape[0], latents.shape[1], 1, 1),
             dtype=latents.dtype,
             device=latents.device,
             generator=generator
-        )
-
-        white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
-        black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}"
-
-        if white_cache_key not in cache:
-            img_white = solid_image(1)
-            cache[white_cache_key] = img_white
-        else:
-            img_white = cache[white_cache_key]
-
-        if black_cache_key not in cache:
-            img_black = solid_image(0)
-            cache[black_cache_key] = img_black
-        else:
-            img_black = cache[black_cache_key]
-
-        offset_strength = torch.rand(
-            (bsz, 1, 1, 1),
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-        offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
-        offset_images = torch.where(
-            offset_strength >= 0,
-            img_white.expand(noise.shape),
-            img_black.expand(noise.shape)
-        )
-        offset_strength = offset_strength.abs().expand(noise.shape)
-        noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2))
+        ).expand(noise.shape)
+        noise += offset_noise_strength * offset_noise
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
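
For context, the hunk above replaces the slerp-based solid-image offset noise with the common per-channel offset-noise formulation. A minimal standalone sketch of that new path is shown below; the batch shape, strength value, and seed are illustrative placeholders rather than values from the repository.

import torch

# Illustrative stand-ins for the tensors available inside loss_step.
latents = torch.randn(4, 4, 64, 64)           # (batch, channels, height, width)
offset_noise_strength = 0.1                   # hypothetical strength value
generator = torch.Generator().manual_seed(0)

noise = torch.randn(latents.shape, dtype=latents.dtype, generator=generator)

if offset_noise_strength != 0:
    # Draw one random value per (sample, channel) and broadcast it over the
    # spatial dimensions, so every pixel of a channel is shifted uniformly.
    offset_noise = torch.randn(
        (latents.shape[0], latents.shape[1], 1, 1),
        dtype=latents.dtype,
        device=latents.device,
        generator=generator,
    ).expand(noise.shape)
    noise = noise + offset_noise_strength * offset_noise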