summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/functional.py97
1 files changed, 89 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py
index 96ecbc1..e7f02cb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -270,6 +270,64 @@ def snr_weight(noisy_latents, latents, gamma):
270 ) 270 )
271 271
272 272
273def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
274 """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
275
276 `DOT_THR` determines when the vectors are too close to parallel.
277 If they are too close, then a regular linear interpolation is used.
278
279 `to_cpu` is a flag that optionally computes SLERP on the CPU.
280 If the input tensors were on a GPU, it moves them back after the computation.
281
282 `zdim` is the feature dimension over which to compute norms and find angles.
283 For example: if a sequence of 5 vectors is input with shape [5, 768]
284 Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
285
286 Theory Reference:
287 https://splines.readthedocs.io/en/latest/rotation/slerp.html
288 PyTorch reference:
289 https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
290 Numpy reference:
291 https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
292 """
293
294 # check if we need to move to the cpu
295 if to_cpu:
296 orig_device = v1.device
297 v1, v2 = v1.to('cpu'), v2.to('cpu')
298
299 # take the dot product between normalized vectors
300 v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
301 v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
302 dot = (v1_norm * v2_norm).sum(zdim)
303
304 for _ in range(len(dot.shape), len(v1.shape)):
305 dot = dot[..., None]
306
307 # if the vectors are too close, return a simple linear interpolation
308 if (torch.abs(dot) > DOT_THR).any():
309 res = (1 - t) * v1 + t * v2
310 else:
311 # compute the angle terms we need
312 theta = torch.acos(dot)
313 theta_t = theta * t
314 sin_theta = torch.sin(theta)
315 sin_theta_t = torch.sin(theta_t)
316
317 # compute the sine scaling terms for the vectors
318 s1 = torch.sin(theta - theta_t) / sin_theta
319 s2 = sin_theta_t / sin_theta
320
321 # interpolate the vectors
322 res = s1 * v1 + s2 * v2
323
324 # check if we need to move them back to the original device
325 if to_cpu:
326 res.to(orig_device)
327
328 return res
329
330
273def loss_step( 331def loss_step(
274 vae: AutoencoderKL, 332 vae: AutoencoderKL,
275 noise_scheduler: SchedulerMixin, 333 noise_scheduler: SchedulerMixin,
@@ -279,10 +337,11 @@ def loss_step(
279 prior_loss_weight: float, 337 prior_loss_weight: float,
280 seed: int, 338 seed: int,
281 offset_noise_strength: float, 339 offset_noise_strength: float,
340 min_snr_gamma: int,
282 step: int, 341 step: int,
283 batch: dict[str, Any], 342 batch: dict[str, Any],
343 cache: dict[Any, Any],
284 eval: bool = False, 344 eval: bool = False,
285 min_snr_gamma: int = 5,
286): 345):
287 images = batch["pixel_values"] 346 images = batch["pixel_values"]
288 generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None 347 generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
@@ -302,13 +361,31 @@ def loss_step(
302 ) 361 )
303 362
304 if offset_noise_strength != 0: 363 if offset_noise_strength != 0:
305 offset_noise = torch.randn( 364 cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
306 (latents.shape[0], latents.shape[1], 1, 1), 365
366 if cache_key not in cache:
367 img_white = torch.tensor(
368 [[[[1]]]],
369 dtype=latents.dtype,
370 device=latents.device
371 ).expand(1, images.shape[1], images.shape[2], images.shape[3])
372 img_white = img_white * 2 - 1
373 img_white = vae.encode(img_white).latent_dist.sample(generator=generator)
374 img_white *= vae.config.scaling_factor
375 cache[cache_key] = img_white
376 else:
377 img_white = cache[cache_key]
378
379 offset_strength = torch.rand(
380 (bsz, 1, 1, 1),
307 dtype=latents.dtype, 381 dtype=latents.dtype,
382 layout=latents.layout,
308 device=latents.device, 383 device=latents.device,
309 generator=generator 384 generator=generator
310 ).expand(noise.shape) 385 )
311 noise += offset_noise_strength * offset_noise 386 offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
387 offset_strength = offset_strength.expand(noise.shape)
388 noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
312 389
313 # Sample a random timestep for each image 390 # Sample a random timestep for each image
314 timesteps = torch.randint( 391 timesteps = torch.randint(
@@ -382,7 +459,8 @@ def loss_step(
382 459
383 460
384class LossCallable(Protocol): 461class LossCallable(Protocol):
385 def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ... 462 def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
463 eval: bool = False) -> Tuple[Any, Any, int]: ...
386 464
387 465
388def train_loop( 466def train_loop(
@@ -407,6 +485,7 @@ def train_loop(
407 num_val_steps = num_val_steps_per_epoch * num_epochs 485 num_val_steps = num_val_steps_per_epoch * num_epochs
408 486
409 global_step = 0 487 global_step = 0
488 cache = {}
410 489
411 avg_loss = AverageMeter() 490 avg_loss = AverageMeter()
412 avg_acc = AverageMeter() 491 avg_acc = AverageMeter()
@@ -476,7 +555,7 @@ def train_loop(
476 555
477 with on_train(epoch): 556 with on_train(epoch):
478 for step, batch in enumerate(train_dataloader): 557 for step, batch in enumerate(train_dataloader):
479 loss, acc, bsz = loss_step(step, batch) 558 loss, acc, bsz = loss_step(step, batch, cache)
480 loss /= gradient_accumulation_steps 559 loss /= gradient_accumulation_steps
481 560
482 accelerator.backward(loss) 561 accelerator.backward(loss)
@@ -541,7 +620,7 @@ def train_loop(
541 620
542 with torch.inference_mode(), on_eval(): 621 with torch.inference_mode(), on_eval():
543 for step, batch in enumerate(val_dataloader): 622 for step, batch in enumerate(val_dataloader):
544 loss, acc, bsz = loss_step(step, batch, True) 623 loss, acc, bsz = loss_step(step, batch, cache, True)
545 624
546 loss = loss.detach_() 625 loss = loss.detach_()
547 acc = acc.detach_() 626 acc = acc.detach_()
@@ -633,6 +712,7 @@ def train(
633 guidance_scale: float = 0.0, 712 guidance_scale: float = 0.0,
634 prior_loss_weight: float = 1.0, 713 prior_loss_weight: float = 1.0,
635 offset_noise_strength: float = 0.15, 714 offset_noise_strength: float = 0.15,
715 min_snr_gamma: int = 5,
636 **kwargs, 716 **kwargs,
637): 717):
638 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 718 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -665,6 +745,7 @@ def train(
665 prior_loss_weight, 745 prior_loss_weight,
666 seed, 746 seed,
667 offset_noise_strength, 747 offset_noise_strength,
748 min_snr_gamma,
668 ) 749 )
669 750
670 if accelerator.is_main_process: 751 if accelerator.is_main_process: