From f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Mar 2023 13:46:36 +0100 Subject: Fixed SNR weighting, re-enabled xformers --- training/functional.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 8dc2b9f..43ee356 100644 --- a/training/functional.py +++ b/training/functional.py @@ -251,6 +251,25 @@ def add_placeholder_tokens( return placeholder_token_ids, initializer_token_ids +def snr_weight(noisy_latents, latents, gamma): + if gamma: + sigma = torch.sub(noisy_latents, latents) + zeros = torch.zeros_like(sigma) + alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) + sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) + snr = torch.div(alpha_mean_sq, sigma_mean_sq) + gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) + snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() + return snr_weight + + return torch.tensor( + [1], + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + ) + + def loss_step( vae: AutoencoderKL, noise_scheduler: SchedulerMixin, @@ -308,21 +327,13 @@ def loss_step( model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type - alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() - snr = alpha_t / (1 - alpha_t) - min_snr = snr.clamp(max=min_snr_gamma) - if noise_scheduler.config.prediction_type == "epsilon": target = noise - loss_weight = min_snr / snr elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) - loss_weight = min_snr / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss_weight = loss_weight[..., None, None, None] - if with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) @@ -339,7 +350,11 @@ def loss_step( else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = (loss_weight * loss).mean([1, 2, 3]).mean() + loss = loss.mean([1, 2, 3]) + + loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma) + loss = (loss_weight * loss).mean() + acc = (model_pred == target).float().mean() return loss, acc, bsz @@ -412,7 +427,7 @@ def train_loop( try: for epoch in range(num_epochs): if accelerator.is_main_process: - if epoch % sample_frequency == 0 and epoch != 0: + if epoch % sample_frequency == 0: local_progress_bar.clear() global_progress_bar.clear() -- cgit v1.2.3-54-g00ecf