From 6c83a18aa8cf1d0d2a972bc8393584eb61b9deac Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 19 Mar 2023 14:37:07 +0100 Subject: Restore min SNR --- training/functional.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 15b95ba..8dc2b9f 100644 --- a/training/functional.py +++ b/training/functional.py @@ -261,7 +261,8 @@ def loss_step( seed: int, step: int, batch: dict[str, Any], - eval: bool = False + eval: bool = False, + min_snr_gamma: int = 5 ): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() @@ -307,23 +308,21 @@ def loss_step( model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type + alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() + snr = alpha_t / (1 - alpha_t) + min_snr = snr.clamp(max=min_snr_gamma) + if noise_scheduler.config.prediction_type == "epsilon": target = noise - - snr_weights = 1 + loss_weight = min_snr / snr elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) - - p2_gamma = 1 - p2_k = 1 - - alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() - snr = 1.0 / (1 - alpha_t) - 1 - snr_weights = 1 / (p2_k + snr) ** p2_gamma - snr_weights = snr_weights[..., None, None, None] + loss_weight = min_snr / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss_weight = loss_weight[..., None, None, None] + if with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) @@ -340,7 +339,7 @@ def loss_step( else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = (snr_weights * loss).mean([1, 2, 3]).mean() + loss = (loss_weight * loss).mean([1, 2, 3]).mean() acc = (model_pred == target).float().mean() return loss, acc, bsz @@ -413,7 +412,7 @@ def train_loop( try: for epoch in range(num_epochs): if accelerator.is_main_process: - if epoch % sample_frequency == 0: + if epoch % sample_frequency == 0 and epoch != 0: local_progress_bar.clear() global_progress_bar.clear() -- cgit v1.2.3-54-g00ecf