summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py33
1 files changed, 21 insertions, 12 deletions
diff --git a/training/functional.py b/training/functional.py
index 38dd59f..e7e1eb3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -324,6 +324,7 @@ def loss_step(
324 prior_loss_weight: float, 324 prior_loss_weight: float,
325 seed: int, 325 seed: int,
326 offset_noise_strength: float, 326 offset_noise_strength: float,
327 input_pertubation: float,
327 disc: Optional[ConvNeXtDiscriminator], 328 disc: Optional[ConvNeXtDiscriminator],
328 min_snr_gamma: int, 329 min_snr_gamma: int,
329 step: int, 330 step: int,
@@ -337,7 +338,7 @@ def loss_step(
337 338
338 # Convert images to latent space 339 # Convert images to latent space
339 latents = vae.encode(images).latent_dist.sample(generator=generator) 340 latents = vae.encode(images).latent_dist.sample(generator=generator)
340 latents *= vae.config.scaling_factor 341 latents = latents * vae.config.scaling_factor
341 342
342 # Sample noise that we'll add to the latents 343 # Sample noise that we'll add to the latents
343 noise = torch.randn( 344 noise = torch.randn(
@@ -355,7 +356,10 @@ def loss_step(
355 device=latents.device, 356 device=latents.device,
356 generator=generator 357 generator=generator
357 ).expand(noise.shape) 358 ).expand(noise.shape)
358 noise += offset_noise_strength * offset_noise 359 noise = noise + offset_noise_strength * offset_noise
360
361 if input_pertubation != 0:
362 new_noise = noise + input_pertubation * torch.randn_like(noise)
359 363
360 # Sample a random timestep for each image 364 # Sample a random timestep for each image
361 timesteps = torch.randint( 365 timesteps = torch.randint(
@@ -369,7 +373,10 @@ def loss_step(
369 373
370 # Add noise to the latents according to the noise magnitude at each timestep 374 # Add noise to the latents according to the noise magnitude at each timestep
371 # (this is the forward diffusion process) 375 # (this is the forward diffusion process)
372 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 376 if input_pertubation != 0:
377 noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
378 else:
379 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
373 noisy_latents = noisy_latents.to(dtype=unet.dtype) 380 noisy_latents = noisy_latents.to(dtype=unet.dtype)
374 381
375 # Get the text embedding for conditioning 382 # Get the text embedding for conditioning
@@ -381,7 +388,7 @@ def loss_step(
381 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) 388 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
382 389
383 # Predict the noise residual 390 # Predict the noise residual
384 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 391 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
385 392
386 if guidance_scale != 0: 393 if guidance_scale != 0:
387 uncond_encoder_hidden_states = get_extended_embeddings( 394 uncond_encoder_hidden_states = get_extended_embeddings(
@@ -391,7 +398,7 @@ def loss_step(
391 ) 398 )
392 uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) 399 uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
393 400
394 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample 401 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
395 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) 402 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
396 403
397 # Get the target for loss depending on the prediction type 404 # Get the target for loss depending on the prediction type
@@ -424,9 +431,9 @@ def loss_step(
424 431
425 if disc is not None: 432 if disc is not None:
426 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) 433 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
427 rec_latent /= vae.config.scaling_factor 434 rec_latent = rec_latent / vae.config.scaling_factor
428 rec_latent = rec_latent.to(dtype=vae.dtype) 435 rec_latent = rec_latent.to(dtype=vae.dtype)
429 rec = vae.decode(rec_latent).sample 436 rec = vae.decode(rec_latent, return_dict=False)[0]
430 loss = 1 - disc.get_score(rec) 437 loss = 1 - disc.get_score(rec)
431 438
432 if min_snr_gamma != 0: 439 if min_snr_gamma != 0:
@@ -434,7 +441,7 @@ def loss_step(
434 mse_loss_weights = ( 441 mse_loss_weights = (
435 torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 442 torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
436 ) 443 )
437 loss *= mse_loss_weights 444 loss = loss * mse_loss_weights
438 445
439 loss = loss.mean() 446 loss = loss.mean()
440 447
@@ -539,7 +546,7 @@ def train_loop(
539 with on_train(cycle): 546 with on_train(cycle):
540 for step, batch in enumerate(train_dataloader): 547 for step, batch in enumerate(train_dataloader):
541 loss, acc, bsz = loss_step(step, batch, cache) 548 loss, acc, bsz = loss_step(step, batch, cache)
542 loss /= gradient_accumulation_steps 549 loss = loss / gradient_accumulation_steps
543 550
544 accelerator.backward(loss) 551 accelerator.backward(loss)
545 552
@@ -598,7 +605,7 @@ def train_loop(
598 with torch.inference_mode(), on_eval(): 605 with torch.inference_mode(), on_eval():
599 for step, batch in enumerate(val_dataloader): 606 for step, batch in enumerate(val_dataloader):
600 loss, acc, bsz = loss_step(step, batch, cache, True) 607 loss, acc, bsz = loss_step(step, batch, cache, True)
601 loss /= gradient_accumulation_steps 608 loss = loss / gradient_accumulation_steps
602 609
603 cur_loss_val.update(loss.item(), bsz) 610 cur_loss_val.update(loss.item(), bsz)
604 cur_acc_val.update(acc.item(), bsz) 611 cur_acc_val.update(acc.item(), bsz)
@@ -684,7 +691,8 @@ def train(
684 global_step_offset: int = 0, 691 global_step_offset: int = 0,
685 guidance_scale: float = 0.0, 692 guidance_scale: float = 0.0,
686 prior_loss_weight: float = 1.0, 693 prior_loss_weight: float = 1.0,
687 offset_noise_strength: float = 0.15, 694 offset_noise_strength: float = 0.01,
695 input_pertubation: float = 0.1,
688 disc: Optional[ConvNeXtDiscriminator] = None, 696 disc: Optional[ConvNeXtDiscriminator] = None,
689 min_snr_gamma: int = 5, 697 min_snr_gamma: int = 5,
690 avg_loss: AverageMeter = AverageMeter(), 698 avg_loss: AverageMeter = AverageMeter(),
@@ -704,7 +712,7 @@ def train(
704 712
705 if compile_unet: 713 if compile_unet:
706 unet = torch.compile(unet, backend='hidet') 714 unet = torch.compile(unet, backend='hidet')
707 # unet = torch.compile(unet) 715 # unet = torch.compile(unet, mode="reduce-overhead")
708 716
709 callbacks = strategy.callbacks( 717 callbacks = strategy.callbacks(
710 accelerator=accelerator, 718 accelerator=accelerator,
@@ -727,6 +735,7 @@ def train(
727 prior_loss_weight, 735 prior_loss_weight,
728 seed, 736 seed,
729 offset_noise_strength, 737 offset_noise_strength,
738 input_pertubation,
730 disc, 739 disc,
731 min_snr_gamma, 740 min_snr_gamma,
732 ) 741 )