From 1386c7badd2930f8a8f8f649216a25f3809a4d96 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 28 Nov 2022 20:27:56 +0100 Subject: Adjusted training to upstream --- dreambooth.py | 57 +++++++++++++++++++++++++++------------------------- textual_inversion.py | 49 +++++++++++++++++++++++--------------------- 2 files changed, 56 insertions(+), 50 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index e9f785c..49d4447 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -893,18 +893,6 @@ def main(): ) global_progress_bar.set_description("Total progress") - def get_loss(noise_pred, noise, latents, timesteps): - if noise_scheduler.config.prediction_type == "v_prediction": - timesteps = timesteps.view(-1, 1, 1, 1) - alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] - alpha_t = torch.sqrt(alphas_cumprod) - sigma_t = torch.sqrt(1 - alphas_cumprod) - target = alpha_t * noise - sigma_t * latents - else: - target = noise - - return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") - try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -937,24 +925,31 @@ def main(): encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.num_class_images != 0: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) @@ -985,7 +980,7 @@ def main(): ema_unet.step(unet) optimizer.zero_grad(set_to_none=True) - acc = (noise_pred == latents).float() + acc = (model_pred == latents).float() acc = acc.mean() total_loss += loss.item() @@ -1006,8 +1001,8 @@ def main(): sample_checkpoint = True logs = { - "train/loss": total_loss / global_step, - "train/acc": total_acc / global_step, + "train/loss": total_loss / global_step if global_step != 0 else 0, + "train/acc": total_acc / global_step if global_step != 0 else 0, "train/cur_loss": loss.item(), "train/cur_acc": acc.item(), "lr/unet": lr_scheduler.get_last_lr()[0], @@ -1043,13 +1038,21 @@ def main(): encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) + model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - acc = (noise_pred == latents).float() + acc = (model_pred == latents).float() acc = acc.mean() total_loss_val += loss.item() diff --git a/textual_inversion.py b/textual_inversion.py index fa7ae42..7ac9638 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -813,18 +813,6 @@ def main(): ) global_progress_bar.set_description("Total progress") - def get_loss(noise_pred, noise, latents, timesteps): - if noise_scheduler.config.prediction_type == "v_prediction": - timesteps = timesteps.view(-1, 1, 1, 1) - alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] - alpha_t = torch.sqrt(alphas_cumprod) - sigma_t = torch.sqrt(1 - alphas_cumprod) - target = alpha_t * noise - sigma_t * latents - else: - target = noise - - return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") - try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -858,24 +846,31 @@ def main(): encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.num_class_images != 0: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) @@ -952,11 +947,19 @@ def main(): encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) + model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-70-g09d2