diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 33 | 
1 file changed, 21 insertions, 12 deletions
| diff --git a/training/functional.py b/training/functional.py index 38dd59f..e7e1eb3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -324,6 +324,7 @@ def loss_step( | |||
| 324 | prior_loss_weight: float, | 324 | prior_loss_weight: float, | 
| 325 | seed: int, | 325 | seed: int, | 
| 326 | offset_noise_strength: float, | 326 | offset_noise_strength: float, | 
| 327 | input_pertubation: float, | ||
| 327 | disc: Optional[ConvNeXtDiscriminator], | 328 | disc: Optional[ConvNeXtDiscriminator], | 
| 328 | min_snr_gamma: int, | 329 | min_snr_gamma: int, | 
| 329 | step: int, | 330 | step: int, | 
| @@ -337,7 +338,7 @@ def loss_step( | |||
| 337 | 338 | ||
| 338 | # Convert images to latent space | 339 | # Convert images to latent space | 
| 339 | latents = vae.encode(images).latent_dist.sample(generator=generator) | 340 | latents = vae.encode(images).latent_dist.sample(generator=generator) | 
| 340 | latents *= vae.config.scaling_factor | 341 | latents = latents * vae.config.scaling_factor | 
| 341 | 342 | ||
| 342 | # Sample noise that we'll add to the latents | 343 | # Sample noise that we'll add to the latents | 
| 343 | noise = torch.randn( | 344 | noise = torch.randn( | 
| @@ -355,7 +356,10 @@ def loss_step( | |||
| 355 | device=latents.device, | 356 | device=latents.device, | 
| 356 | generator=generator | 357 | generator=generator | 
| 357 | ).expand(noise.shape) | 358 | ).expand(noise.shape) | 
| 358 | noise += offset_noise_strength * offset_noise | 359 | noise = noise + offset_noise_strength * offset_noise | 
| 360 | |||
| 361 | if input_pertubation != 0: | ||
| 362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | ||
| 359 | 363 | ||
| 360 | # Sample a random timestep for each image | 364 | # Sample a random timestep for each image | 
| 361 | timesteps = torch.randint( | 365 | timesteps = torch.randint( | 
| @@ -369,7 +373,10 @@ def loss_step( | |||
| 369 | 373 | ||
| 370 | # Add noise to the latents according to the noise magnitude at each timestep | 374 | # Add noise to the latents according to the noise magnitude at each timestep | 
| 371 | # (this is the forward diffusion process) | 375 | # (this is the forward diffusion process) | 
| 372 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 376 | if input_pertubation != 0: | 
| 377 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) | ||
| 378 | else: | ||
| 379 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 373 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 380 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 
| 374 | 381 | ||
| 375 | # Get the text embedding for conditioning | 382 | # Get the text embedding for conditioning | 
| @@ -381,7 +388,7 @@ def loss_step( | |||
| 381 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 388 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 
| 382 | 389 | ||
| 383 | # Predict the noise residual | 390 | # Predict the noise residual | 
| 384 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 391 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] | 
| 385 | 392 | ||
| 386 | if guidance_scale != 0: | 393 | if guidance_scale != 0: | 
| 387 | uncond_encoder_hidden_states = get_extended_embeddings( | 394 | uncond_encoder_hidden_states = get_extended_embeddings( | 
| @@ -391,7 +398,7 @@ def loss_step( | |||
| 391 | ) | 398 | ) | 
| 392 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 399 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 
| 393 | 400 | ||
| 394 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 401 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] | 
| 395 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 402 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 
| 396 | 403 | ||
| 397 | # Get the target for loss depending on the prediction type | 404 | # Get the target for loss depending on the prediction type | 
| @@ -424,9 +431,9 @@ def loss_step( | |||
| 424 | 431 | ||
| 425 | if disc is not None: | 432 | if disc is not None: | 
| 426 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | 433 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | 
| 427 | rec_latent /= vae.config.scaling_factor | 434 | rec_latent = rec_latent / vae.config.scaling_factor | 
| 428 | rec_latent = rec_latent.to(dtype=vae.dtype) | 435 | rec_latent = rec_latent.to(dtype=vae.dtype) | 
| 429 | rec = vae.decode(rec_latent).sample | 436 | rec = vae.decode(rec_latent, return_dict=False)[0] | 
| 430 | loss = 1 - disc.get_score(rec) | 437 | loss = 1 - disc.get_score(rec) | 
| 431 | 438 | ||
| 432 | if min_snr_gamma != 0: | 439 | if min_snr_gamma != 0: | 
| @@ -434,7 +441,7 @@ def loss_step( | |||
| 434 | mse_loss_weights = ( | 441 | mse_loss_weights = ( | 
| 435 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 442 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 
| 436 | ) | 443 | ) | 
| 437 | loss *= mse_loss_weights | 444 | loss = loss * mse_loss_weights | 
| 438 | 445 | ||
| 439 | loss = loss.mean() | 446 | loss = loss.mean() | 
| 440 | 447 | ||
| @@ -539,7 +546,7 @@ def train_loop( | |||
| 539 | with on_train(cycle): | 546 | with on_train(cycle): | 
| 540 | for step, batch in enumerate(train_dataloader): | 547 | for step, batch in enumerate(train_dataloader): | 
| 541 | loss, acc, bsz = loss_step(step, batch, cache) | 548 | loss, acc, bsz = loss_step(step, batch, cache) | 
| 542 | loss /= gradient_accumulation_steps | 549 | loss = loss / gradient_accumulation_steps | 
| 543 | 550 | ||
| 544 | accelerator.backward(loss) | 551 | accelerator.backward(loss) | 
| 545 | 552 | ||
| @@ -598,7 +605,7 @@ def train_loop( | |||
| 598 | with torch.inference_mode(), on_eval(): | 605 | with torch.inference_mode(), on_eval(): | 
| 599 | for step, batch in enumerate(val_dataloader): | 606 | for step, batch in enumerate(val_dataloader): | 
| 600 | loss, acc, bsz = loss_step(step, batch, cache, True) | 607 | loss, acc, bsz = loss_step(step, batch, cache, True) | 
| 601 | loss /= gradient_accumulation_steps | 608 | loss = loss / gradient_accumulation_steps | 
| 602 | 609 | ||
| 603 | cur_loss_val.update(loss.item(), bsz) | 610 | cur_loss_val.update(loss.item(), bsz) | 
| 604 | cur_acc_val.update(acc.item(), bsz) | 611 | cur_acc_val.update(acc.item(), bsz) | 
| @@ -684,7 +691,8 @@ def train( | |||
| 684 | global_step_offset: int = 0, | 691 | global_step_offset: int = 0, | 
| 685 | guidance_scale: float = 0.0, | 692 | guidance_scale: float = 0.0, | 
| 686 | prior_loss_weight: float = 1.0, | 693 | prior_loss_weight: float = 1.0, | 
| 687 | offset_noise_strength: float = 0.15, | 694 | offset_noise_strength: float = 0.01, | 
| 695 | input_pertubation: float = 0.1, | ||
| 688 | disc: Optional[ConvNeXtDiscriminator] = None, | 696 | disc: Optional[ConvNeXtDiscriminator] = None, | 
| 689 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, | 
| 690 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), | 
| @@ -704,7 +712,7 @@ def train( | |||
| 704 | 712 | ||
| 705 | if compile_unet: | 713 | if compile_unet: | 
| 706 | unet = torch.compile(unet, backend='hidet') | 714 | unet = torch.compile(unet, backend='hidet') | 
| 707 | # unet = torch.compile(unet) | 715 | # unet = torch.compile(unet, mode="reduce-overhead") | 
| 708 | 716 | ||
| 709 | callbacks = strategy.callbacks( | 717 | callbacks = strategy.callbacks( | 
| 710 | accelerator=accelerator, | 718 | accelerator=accelerator, | 
| @@ -727,6 +735,7 @@ def train( | |||
| 727 | prior_loss_weight, | 735 | prior_loss_weight, | 
| 728 | seed, | 736 | seed, | 
| 729 | offset_noise_strength, | 737 | offset_noise_strength, | 
| 738 | input_pertubation, | ||
| 730 | disc, | 739 | disc, | 
| 731 | min_snr_gamma, | 740 | min_snr_gamma, | 
| 732 | ) | 741 | ) | 
