From b8ba49fe4c44aaaa30894e5abba22d3bbf94a562 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 28 Nov 2022 13:23:05 +0100
Subject: Fixed noise calculation for v-prediction

---
 dreambooth.py | 31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index d15f1ee..e9f785c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,6 +893,26 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    def get_loss(noise_pred, noise, latents, timesteps):
+        """Return the mean MSE between the model output and its training target.
+
+        For "v_prediction" schedulers the target is the velocity
+        alpha_t * noise - sigma_t * latents; otherwise it is the noise itself.
+        timesteps must hold one entry per sample of noise/latents.
+        """
+        if noise_scheduler.config.prediction_type == "v_prediction":
+            timesteps = timesteps.view(-1, 1, 1, 1)
+            # alphas_cumprod may live on another device than the timesteps
+            # (schedulers typically build it on the CPU); move it before indexing.
+            alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
+            alpha_t = torch.sqrt(alphas_cumprod)
+            sigma_t = torch.sqrt(1 - alphas_cumprod)
+            target = alpha_t * noise - sigma_t * latents
+        else:
+            target = noise
+
+        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -929,19 +949,22 @@ def main():
 
                     if args.num_class_images != 0:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                        # latents and timesteps are split too so each get_loss call sees matching batch sizes.
+                        latents, latents_prior = torch.chunk(latents, 2, dim=0)
+                        timesteps, timesteps_prior = torch.chunk(timesteps, 2, dim=0)
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
                         # Compute instance loss
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                        loss = get_loss(noise_pred, noise, latents, timesteps)
 
                         # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                        prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps_prior)
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                        loss = get_loss(noise_pred, noise, latents, timesteps)
 
                     accelerator.backward(loss)
 
@@ -1034,7 +1057,7 @@ def main():
 
                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                    loss = get_loss(noise_pred, noise, latents, timesteps)
 
                     acc = (noise_pred == latents).float()
                     acc = acc.mean()
-- 
cgit v1.2.3-54-g00ecf