From 1386c7badd2930f8a8f8f649216a25f3809a4d96 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 28 Nov 2022 20:27:56 +0100 Subject: Adjusted training to upstream --- textual_inversion.py | 49 ++++++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 23 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index fa7ae42..7ac9638 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -813,18 +813,6 @@ def main(): ) global_progress_bar.set_description("Total progress") - def get_loss(noise_pred, noise, latents, timesteps): - if noise_scheduler.config.prediction_type == "v_prediction": - timesteps = timesteps.view(-1, 1, 1, 1) - alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] - alpha_t = torch.sqrt(alphas_cumprod) - sigma_t = torch.sqrt(1 - alphas_cumprod) - target = alpha_t * noise - sigma_t * latents - else: - target = noise - - return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") - try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -858,24 +846,31 @@ def main(): encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.num_class_images != 0: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) @@ -952,11 +947,19 @@ def main(): encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) + model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss = get_loss(noise_pred, noise, latents, timesteps) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-54-g00ecf