Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  110
1 file changed, 39 insertions(+), 71 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index e7f02cb..68071bc 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,7 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from util.noise import perlin_noise
+from util.slerp import slerp
 
 
 def const(result=None):
@@ -270,62 +270,16 @@ def snr_weight(noisy_latents, latents, gamma):
     )
 
 
-def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
-    """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
-
-    `DOT_THR` determines when the vectors are too close to parallel.
-    If they are too close, then a regular linear interpolation is used.
-
-    `to_cpu` is a flag that optionally computes SLERP on the CPU.
-    If the input tensors were on a GPU, it moves them back after the computation.
-
-    `zdim` is the feature dimension over which to compute norms and find angles.
-    For example: if a sequence of 5 vectors is input with shape [5, 768]
-    Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
-
-    Theory Reference:
-    https://splines.readthedocs.io/en/latest/rotation/slerp.html
-    PyTorch reference:
-    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-    Numpy reference:
-    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
-    """
-
-    # check if we need to move to the cpu
-    if to_cpu:
-        orig_device = v1.device
-        v1, v2 = v1.to('cpu'), v2.to('cpu')
-
-    # take the dot product between normalized vectors
-    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
-    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
-    dot = (v1_norm * v2_norm).sum(zdim)
-
-    for _ in range(len(dot.shape), len(v1.shape)):
-        dot = dot[..., None]
-
-    # if the vectors are too close, return a simple linear interpolation
-    if (torch.abs(dot) > DOT_THR).any():
-        res = (1 - t) * v1 + t * v2
-    else:
-        # compute the angle terms we need
-        theta = torch.acos(dot)
-        theta_t = theta * t
-        sin_theta = torch.sin(theta)
-        sin_theta_t = torch.sin(theta_t)
-
-        # compute the sine scaling terms for the vectors
-        s1 = torch.sin(theta - theta_t) / sin_theta
-        s2 = sin_theta_t / sin_theta
-
-        # interpolate the vectors
-        res = s1 * v1 + s2 * v2
-
-    # check if we need to move them back to the original device
-    if to_cpu:
-        res.to(orig_device)
-
-    return res
+def make_solid_image(color: float, shape, vae, dtype, device, generator):
+    img = torch.tensor(
+        [[[[color]]]],
+        dtype=dtype,
+        device=device
+    ).expand(1, *shape)
+    img = img * 2 - 1
+    img = vae.encode(img).latent_dist.sample(generator=generator)
+    img *= vae.config.scaling_factor
+    return img
 
 
 def loss_step(
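Note: the hunk above extracts the inline white-image encoding into the reusable make_solid_image helper; the slerp implementation it deletes now lives in util/slerp.py (see the import change in the first hunk). A minimal usage sketch of the helper, assuming a diffusers AutoencoderKL as the vae; the checkpoint name and image sizes are illustrative assumptions, not part of this commit:

# Usage sketch (assumptions noted above; not code from this commit).
import torch
from diffusers import AutoencoderKL
from training.functional import make_solid_image

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
generator = torch.Generator().manual_seed(0)

# Encode a solid mid-grey image; loss_step below does the same for white (1)
# and black (0) and caches the resulting latents per image size.
latent = make_solid_image(
    0.5,                  # colour in [0, 1]; the helper rescales to [-1, 1]
    shape=(3, 512, 512),  # (channels, height, width) of the training images
    vae=vae,
    dtype=torch.float32,
    device="cpu",
    generator=generator,
)
# latent: (1, 4, 64, 64) for a 512x512 input with the 8x VAE downscale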
@@ -361,20 +315,29 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
+        solid_image = partial(
+            make_solid_image,
+            shape=images.shape[1:],
+            vae=vae,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
+        white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
+        black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}"
 
-        if cache_key not in cache:
-            img_white = torch.tensor(
-                [[[[1]]]],
-                dtype=latents.dtype,
-                device=latents.device
-            ).expand(1, images.shape[1], images.shape[2], images.shape[3])
-            img_white = img_white * 2 - 1
-            img_white = vae.encode(img_white).latent_dist.sample(generator=generator)
-            img_white *= vae.config.scaling_factor
-            cache[cache_key] = img_white
+        if white_cache_key not in cache:
+            img_white = solid_image(1)
+            cache[white_cache_key] = img_white
         else:
-            img_white = cache[cache_key]
+            img_white = cache[white_cache_key]
+
+        if black_cache_key not in cache:
+            img_black = solid_image(0)
+            cache[black_cache_key] = img_black
+        else:
+            img_black = cache[black_cache_key]
 
         offset_strength = torch.rand(
             (bsz, 1, 1, 1),
@@ -384,8 +347,13 @@ def loss_step(
             generator=generator
         )
         offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
-        offset_strength = offset_strength.expand(noise.shape)
-        noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
+        offset_images = torch.where(
+            offset_strength >= 0,
+            img_white.expand(noise.shape),
+            img_black.expand(noise.shape)
+        )
+        offset_strength = offset_strength.abs().expand(noise.shape)
+        noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2))
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
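Net effect of the last two hunks: instead of always blending the noise toward the white latent with a strength drawn from [-s, s] (where the interpolation factor could go negative), the strength's sign now selects white or black as the target and its magnitude drives the slerp. A standalone sketch of that scheme with dummy latents; the shapes and random stand-in tensors are assumptions for illustration:

# Sketch of the signed offset-noise scheme (dummy tensors, assumed shapes).
import torch
from util.slerp import slerp  # the helper this commit moves out of functional.py

bsz, latent_shape = 4, (4, 64, 64)
offset_noise_strength = 0.15

noise = torch.randn((bsz, *latent_shape))
img_white = torch.randn((1, *latent_shape))  # stand-ins for the cached VAE latents
img_black = torch.randn((1, *latent_shape))

# one strength per image, uniform in [-offset_noise_strength, +offset_noise_strength]
offset_strength = offset_noise_strength * (torch.rand((bsz, 1, 1, 1)) * 2 - 1)

# the sign picks the target latent, the magnitude is the interpolation factor
offset_images = torch.where(
    offset_strength >= 0,
    img_white.expand(noise.shape),
    img_black.expand(noise.shape),
)
noise = slerp(noise, offset_images, offset_strength.abs().expand(noise.shape), zdim=(-1, -2))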
