diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-06 16:06:04 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-06 16:06:04 +0200 | 
| commit | ab24e5cbd8283ad4ced486e1369484ebf9e3962d (patch) | |
| tree | 7d47b7cd38e7313071ad4a671b14f8a23dcd7389 /training/functional.py | |
| parent | MinSNR code from diffusers (diff) | |
| download | textual-inversion-diff-ab24e5cbd8283ad4ced486e1369484ebf9e3962d.tar.gz textual-inversion-diff-ab24e5cbd8283ad4ced486e1369484ebf9e3962d.tar.bz2 textual-inversion-diff-ab24e5cbd8283ad4ced486e1369484ebf9e3962d.zip | |
Update
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 40 | 
1 file changed, 4 insertions, 36 deletions
| diff --git a/training/functional.py b/training/functional.py index 06848cb..c30d1c0 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -321,45 +321,13 @@ def loss_step( | |||
| 321 | ) | 321 | ) | 
| 322 | 322 | ||
| 323 | if offset_noise_strength != 0: | 323 | if offset_noise_strength != 0: | 
| 324 | solid_image = partial( | 324 | offset_noise = torch.randn( | 
| 325 | make_solid_image, | 325 | (latents.shape[0], latents.shape[1], 1, 1), | 
| 326 | shape=images.shape[1:], | ||
| 327 | vae=vae, | ||
| 328 | dtype=latents.dtype, | 326 | dtype=latents.dtype, | 
| 329 | device=latents.device, | 327 | device=latents.device, | 
| 330 | generator=generator | 328 | generator=generator | 
| 331 | ) | 329 | ).expand(noise.shape) | 
| 332 | 330 | noise += offset_noise_strength * offset_noise | |
| 333 | white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" | ||
| 334 | black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}" | ||
| 335 | |||
| 336 | if white_cache_key not in cache: | ||
| 337 | img_white = solid_image(1) | ||
| 338 | cache[white_cache_key] = img_white | ||
| 339 | else: | ||
| 340 | img_white = cache[white_cache_key] | ||
| 341 | |||
| 342 | if black_cache_key not in cache: | ||
| 343 | img_black = solid_image(0) | ||
| 344 | cache[black_cache_key] = img_black | ||
| 345 | else: | ||
| 346 | img_black = cache[black_cache_key] | ||
| 347 | |||
| 348 | offset_strength = torch.rand( | ||
| 349 | (bsz, 1, 1, 1), | ||
| 350 | dtype=latents.dtype, | ||
| 351 | layout=latents.layout, | ||
| 352 | device=latents.device, | ||
| 353 | generator=generator | ||
| 354 | ) | ||
| 355 | offset_strength = offset_noise_strength * (offset_strength * 2 - 1) | ||
| 356 | offset_images = torch.where( | ||
| 357 | offset_strength >= 0, | ||
| 358 | img_white.expand(noise.shape), | ||
| 359 | img_black.expand(noise.shape) | ||
| 360 | ) | ||
| 361 | offset_strength = offset_strength.abs().expand(noise.shape) | ||
| 362 | noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2)) | ||
| 363 | 331 | ||
| 364 | # Sample a random timestep for each image | 332 | # Sample a random timestep for each image | 
| 365 | timesteps = torch.randint( | 333 | timesteps = torch.randint( | 
