diff options
-rw-r--r-- | training/functional.py | 97 |
1 files changed, 89 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py index 96ecbc1..e7f02cb 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -270,6 +270,64 @@ def snr_weight(noisy_latents, latents, gamma): | |||
270 | ) | 270 | ) |
271 | 271 | ||
272 | 272 | ||
273 | def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1): | ||
274 | """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`. | ||
275 | |||
276 | `DOT_THR` determines when the vectors are too close to parallel. | ||
277 | If they are too close, then a regular linear interpolation is used. | ||
278 | |||
279 | `to_cpu` is a flag that optionally computes SLERP on the CPU. | ||
280 | If the input tensors were on a GPU, it moves them back after the computation. | ||
281 | |||
282 | `zdim` is the feature dimension over which to compute norms and find angles. | ||
283 | For example: if a sequence of 5 vectors is input with shape [5, 768] | ||
284 | Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768. | ||
285 | |||
286 | Theory Reference: | ||
287 | https://splines.readthedocs.io/en/latest/rotation/slerp.html | ||
288 | PyTorch reference: | ||
289 | https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 | ||
290 | Numpy reference: | ||
291 | https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c | ||
292 | """ | ||
293 | |||
294 | # check if we need to move to the cpu | ||
295 | if to_cpu: | ||
296 | orig_device = v1.device | ||
297 | v1, v2 = v1.to('cpu'), v2.to('cpu') | ||
298 | |||
299 | # take the dot product between normalized vectors | ||
300 | v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True) | ||
301 | v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True) | ||
302 | dot = (v1_norm * v2_norm).sum(zdim) | ||
303 | |||
304 | for _ in range(len(dot.shape), len(v1.shape)): | ||
305 | dot = dot[..., None] | ||
306 | |||
307 | # if the vectors are too close, return a simple linear interpolation | ||
308 | if (torch.abs(dot) > DOT_THR).any(): | ||
309 | res = (1 - t) * v1 + t * v2 | ||
310 | else: | ||
311 | # compute the angle terms we need | ||
312 | theta = torch.acos(dot) | ||
313 | theta_t = theta * t | ||
314 | sin_theta = torch.sin(theta) | ||
315 | sin_theta_t = torch.sin(theta_t) | ||
316 | |||
317 | # compute the sine scaling terms for the vectors | ||
318 | s1 = torch.sin(theta - theta_t) / sin_theta | ||
319 | s2 = sin_theta_t / sin_theta | ||
320 | |||
321 | # interpolate the vectors | ||
322 | res = s1 * v1 + s2 * v2 | ||
323 | |||
324 | # check if we need to move them back to the original device | ||
325 | if to_cpu: | ||
326 | res.to(orig_device) | ||
327 | |||
328 | return res | ||
329 | |||
330 | |||
273 | def loss_step( | 331 | def loss_step( |
274 | vae: AutoencoderKL, | 332 | vae: AutoencoderKL, |
275 | noise_scheduler: SchedulerMixin, | 333 | noise_scheduler: SchedulerMixin, |
@@ -279,10 +337,11 @@ def loss_step( | |||
279 | prior_loss_weight: float, | 337 | prior_loss_weight: float, |
280 | seed: int, | 338 | seed: int, |
281 | offset_noise_strength: float, | 339 | offset_noise_strength: float, |
340 | min_snr_gamma: int, | ||
282 | step: int, | 341 | step: int, |
283 | batch: dict[str, Any], | 342 | batch: dict[str, Any], |
343 | cache: dict[Any, Any], | ||
284 | eval: bool = False, | 344 | eval: bool = False, |
285 | min_snr_gamma: int = 5, | ||
286 | ): | 345 | ): |
287 | images = batch["pixel_values"] | 346 | images = batch["pixel_values"] |
288 | generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None | 347 | generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None |
@@ -302,13 +361,31 @@ def loss_step( | |||
302 | ) | 361 | ) |
303 | 362 | ||
304 | if offset_noise_strength != 0: | 363 | if offset_noise_strength != 0: |
305 | offset_noise = torch.randn( | 364 | cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" |
306 | (latents.shape[0], latents.shape[1], 1, 1), | 365 | |
366 | if cache_key not in cache: | ||
367 | img_white = torch.tensor( | ||
368 | [[[[1]]]], | ||
369 | dtype=latents.dtype, | ||
370 | device=latents.device | ||
371 | ).expand(1, images.shape[1], images.shape[2], images.shape[3]) | ||
372 | img_white = img_white * 2 - 1 | ||
373 | img_white = vae.encode(img_white).latent_dist.sample(generator=generator) | ||
374 | img_white *= vae.config.scaling_factor | ||
375 | cache[cache_key] = img_white | ||
376 | else: | ||
377 | img_white = cache[cache_key] | ||
378 | |||
379 | offset_strength = torch.rand( | ||
380 | (bsz, 1, 1, 1), | ||
307 | dtype=latents.dtype, | 381 | dtype=latents.dtype, |
382 | layout=latents.layout, | ||
308 | device=latents.device, | 383 | device=latents.device, |
309 | generator=generator | 384 | generator=generator |
310 | ).expand(noise.shape) | 385 | ) |
311 | noise += offset_noise_strength * offset_noise | 386 | offset_strength = offset_noise_strength * (offset_strength * 2 - 1) |
387 | offset_strength = offset_strength.expand(noise.shape) | ||
388 | noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2)) | ||
312 | 389 | ||
313 | # Sample a random timestep for each image | 390 | # Sample a random timestep for each image |
314 | timesteps = torch.randint( | 391 | timesteps = torch.randint( |
@@ -382,7 +459,8 @@ def loss_step( | |||
382 | 459 | ||
383 | 460 | ||
384 | class LossCallable(Protocol): | 461 | class LossCallable(Protocol): |
385 | def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ... | 462 | def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any], |
463 | eval: bool = False) -> Tuple[Any, Any, int]: ... | ||
386 | 464 | ||
387 | 465 | ||
388 | def train_loop( | 466 | def train_loop( |
@@ -407,6 +485,7 @@ def train_loop( | |||
407 | num_val_steps = num_val_steps_per_epoch * num_epochs | 485 | num_val_steps = num_val_steps_per_epoch * num_epochs |
408 | 486 | ||
409 | global_step = 0 | 487 | global_step = 0 |
488 | cache = {} | ||
410 | 489 | ||
411 | avg_loss = AverageMeter() | 490 | avg_loss = AverageMeter() |
412 | avg_acc = AverageMeter() | 491 | avg_acc = AverageMeter() |
@@ -476,7 +555,7 @@ def train_loop( | |||
476 | 555 | ||
477 | with on_train(epoch): | 556 | with on_train(epoch): |
478 | for step, batch in enumerate(train_dataloader): | 557 | for step, batch in enumerate(train_dataloader): |
479 | loss, acc, bsz = loss_step(step, batch) | 558 | loss, acc, bsz = loss_step(step, batch, cache) |
480 | loss /= gradient_accumulation_steps | 559 | loss /= gradient_accumulation_steps |
481 | 560 | ||
482 | accelerator.backward(loss) | 561 | accelerator.backward(loss) |
@@ -541,7 +620,7 @@ def train_loop( | |||
541 | 620 | ||
542 | with torch.inference_mode(), on_eval(): | 621 | with torch.inference_mode(), on_eval(): |
543 | for step, batch in enumerate(val_dataloader): | 622 | for step, batch in enumerate(val_dataloader): |
544 | loss, acc, bsz = loss_step(step, batch, True) | 623 | loss, acc, bsz = loss_step(step, batch, cache, True) |
545 | 624 | ||
546 | loss = loss.detach_() | 625 | loss = loss.detach_() |
547 | acc = acc.detach_() | 626 | acc = acc.detach_() |
@@ -633,6 +712,7 @@ def train( | |||
633 | guidance_scale: float = 0.0, | 712 | guidance_scale: float = 0.0, |
634 | prior_loss_weight: float = 1.0, | 713 | prior_loss_weight: float = 1.0, |
635 | offset_noise_strength: float = 0.15, | 714 | offset_noise_strength: float = 0.15, |
715 | min_snr_gamma: int = 5, | ||
636 | **kwargs, | 716 | **kwargs, |
637 | ): | 717 | ): |
638 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 718 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -665,6 +745,7 @@ def train( | |||
665 | prior_loss_weight, | 745 | prior_loss_weight, |
666 | seed, | 746 | seed, |
667 | offset_noise_strength, | 747 | offset_noise_strength, |
748 | min_snr_gamma, | ||
668 | ) | 749 | ) |
669 | 750 | ||
670 | if accelerator.is_main_process: | 751 | if accelerator.is_main_process: |