 training/functional.py | 97 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 89 insertions(+), 8 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 96ecbc1..e7f02cb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -270,6 +270,64 @@ def snr_weight(noisy_latents, latents, gamma):
     )
 
 
+def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
+    """SLERP for PyTorch tensors, interpolating `v1` to `v2` with scale `t`.
+
+    `DOT_THR` determines when the vectors are too close to parallel.
+    If they are too close, a regular linear interpolation is used instead.
+
+    `to_cpu` optionally runs the computation on the CPU.
+    If the input tensors were on a GPU, the result is moved back afterwards.
+
+    `zdim` is the feature dimension over which to compute norms and find angles.
+    For example: if a sequence of 5 vectors is input with shape [5, 768],
+    then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
+
+    Theory reference:
+    https://splines.readthedocs.io/en/latest/rotation/slerp.html
+    PyTorch reference:
+    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
+    NumPy reference:
+    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
+    """
+
+    # check if we need to move to the cpu
+    if to_cpu:
+        orig_device = v1.device
+        v1, v2 = v1.to('cpu'), v2.to('cpu')
+
+    # take the dot product between normalized vectors
+    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
+    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
+    dot = (v1_norm * v2_norm).sum(zdim)
+
+    for _ in range(len(dot.shape), len(v1.shape)):  # unsqueeze so dot broadcasts
+        dot = dot[..., None]
+
+    # if any pair is too close to parallel, fall back to linear interpolation
+    if (torch.abs(dot) > DOT_THR).any():
+        res = (1 - t) * v1 + t * v2
+    else:
+        # compute the angle terms we need
+        theta = torch.acos(dot)
+        theta_t = theta * t
+        sin_theta = torch.sin(theta)
+        sin_theta_t = torch.sin(theta_t)
+
+        # compute the sine scaling terms for the vectors
+        s1 = torch.sin(theta - theta_t) / sin_theta
+        s2 = sin_theta_t / sin_theta
+
+        # interpolate the vectors
+        res = s1 * v1 + s2 * v2
+
+    # move the result back to the original device if needed
+    if to_cpu:
+        res = res.to(orig_device)
+
+    return res
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
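
Why SLERP rather than a plain average: linearly interpolating two Gaussian noise tensors shrinks their norm (the midpoint of two near-orthogonal vectors lies well inside the sphere), while SLERP walks along the great circle and keeps the result noise-like. A standalone sanity check of the new helper (a sketch assuming only `torch` plus the `slerp` defined above; the shapes mirror the docstring example):

    import torch

    v1 = torch.randn(5, 768)          # batch of 5 feature vectors
    v2 = torch.randn(5, 768)

    mid_slerp = slerp(v1, v2, t=0.5)  # halfway along the great circle
    mid_lerp = (v1 + v2) / 2          # plain linear midpoint

    # slerp roughly preserves the vectors' scale; for near-orthogonal Gaussian
    # inputs, lerp shrinks it by about sqrt(2)
    print(v1.norm(dim=-1).mean(), mid_slerp.norm(dim=-1).mean(), mid_lerp.norm(dim=-1).mean())
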
@@ -279,10 +337,11 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
+    cache: dict[Any, Any],
     eval: bool = False,
-    min_snr_gamma: int = 5,
 ):
     images = batch["pixel_values"]
     generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
@@ -302,13 +361,31 @@
     )
 
     if offset_noise_strength != 0:
-        offset_noise = torch.randn(
-            (latents.shape[0], latents.shape[1], 1, 1),
+        cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
+
+        if cache_key not in cache:
+            img_white = torch.tensor(
+                [[[[1]]]],
+                dtype=latents.dtype,
+                device=latents.device
+            ).expand(1, images.shape[1], images.shape[2], images.shape[3])
+            img_white = img_white * 2 - 1
+            img_white = vae.encode(img_white).latent_dist.sample(generator=generator)
+            img_white *= vae.config.scaling_factor
+            cache[cache_key] = img_white
+        else:
+            img_white = cache[cache_key]
+
+        offset_strength = torch.rand(
+            (bsz, 1, 1, 1),
             dtype=latents.dtype,
+            layout=latents.layout,
             device=latents.device,
             generator=generator
-        ).expand(noise.shape)
-        noise += offset_noise_strength * offset_noise
+        )
+        offset_strength = offset_noise_strength * (offset_strength * 2 - 1)  # uniform in [-s, s]
+        offset_strength = offset_strength.expand(noise.shape)
+        noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
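
The rewritten block replaces additive offset noise with a spherical interpolation between the sampled noise and the VAE latent of a solid white image, encoded once per input resolution and memoized in `cache`. With `zdim=(-1, -2)`, norms and angles are computed per channel over the spatial dims. A shape-level sketch of the call (standalone, reusing the `slerp` above; the random tensor stands in for the cached white latent, and 0.15 mirrors the default `offset_noise_strength`):

    import torch

    bsz, c, h, w = 4, 4, 96, 96
    noise = torch.randn(bsz, c, h, w)
    img_white = torch.randn(1, c, h, w)    # stand-in for the cached white-image latent

    strength = torch.rand(bsz, 1, 1, 1)
    strength = 0.15 * (strength * 2 - 1)   # per-sample scale, uniform in [-0.15, 0.15]
    strength = strength.expand(noise.shape)

    # angles are taken per channel over the spatial dims (-1, -2)
    shifted = slerp(noise, img_white.expand(noise.shape), strength, zdim=(-1, -2))
    print(shifted.shape)                   # torch.Size([4, 4, 96, 96])

Because the strength is symmetric around zero, a negative value extrapolates away from the white latent, presumably biasing training toward both brighter and darker samples instead of only brighter ones.
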
@@ -382,7 +459,8 @@
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(self, step: int, batch: dict[str, Any], cache: dict[Any, Any],
+                 eval: bool = False) -> Tuple[Any, Any, int]: ...
 
 
 def train_loop(
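
The widened `Protocol` keeps `train_loop` decoupled from `loss_step` itself: any callable with this shape type-checks. A minimal conformance sketch (standalone; the dummy function is hypothetical):

    from typing import Any, Protocol, Tuple

    class LossCallable(Protocol):
        def __call__(self, step: int, batch: dict[str, Any], cache: dict[Any, Any],
                     eval: bool = False) -> Tuple[Any, Any, int]: ...

    def dummy_loss_step(step: int, batch: dict[str, Any], cache: dict[Any, Any],
                        eval: bool = False) -> Tuple[Any, Any, int]:
        return 0.0, 0.0, 1  # (loss, accuracy, batch size)

    fn: LossCallable = dummy_loss_step  # accepted structurally, no inheritance needed
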
@@ -407,6 +485,7 @@ def train_loop(
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
+    cache = {}
 
     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
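
`cache` is created once per run and handed to every `loss_step` call in both loops below, so each image resolution pays the VAE encode for its white latent exactly once. The pattern in miniature (standalone toy; `print` stands in for the one-time encode):

    cache: dict = {}

    def white_latent(h: int, w: int):
        key = f"img_white_{h}_{w}"
        if key not in cache:
            print(f"encoding {key}")   # the expensive VAE call happens only here
            cache[key] = (h, w)        # a real latent tensor in the actual code
        return cache[key]

    white_latent(512, 512)   # prints: encoding img_white_512_512
    white_latent(512, 512)   # silent cache hit
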
@@ -476,7 +555,7 @@
 
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
-                loss, acc, bsz = loss_step(step, batch)
+                loss, acc, bsz = loss_step(step, batch, cache)
                 loss /= gradient_accumulation_steps
 
                 accelerator.backward(loss)
@@ -541,7 +620,7 @@
 
         with torch.inference_mode(), on_eval():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loss_step(step, batch, True)
+                loss, acc, bsz = loss_step(step, batch, cache, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
@@ -633,6 +712,7 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     offset_noise_strength: float = 0.15,
+    min_snr_gamma: int = 5,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -665,6 +745,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        min_snr_gamma,
     )
 
     if accelerator.is_main_process:
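
On `min_snr_gamma`: the commit moves it from a hard-coded keyword default on `loss_step` to a `train()` parameter threaded through the same call chain. The first hunk's header shows a `snr_weight(noisy_latents, latents, gamma)` helper, which presumably implements the Min-SNR weighting strategy of Hang et al. (2023): clamp each timestep's signal-to-noise ratio at `gamma` so low-noise timesteps stop dominating the loss. A sketch of the standard formulation for epsilon-prediction (an assumption about `snr_weight`'s internals, not a copy of it):

    import torch

    def min_snr_weight(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
        # per-timestep loss weight: min(SNR, gamma) / SNR
        return snr.clamp(max=gamma) / snr

With the default `gamma = 5`, timesteps whose SNR exceeds 5 are down-weighted proportionally, while noisier timesteps (SNR below 5) keep weight 1.
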
