summary | refs | log | tree | commit | diff | stats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  49
1 file changed, 26 insertions, 23 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index fa7ae42..7ac9638 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -813,18 +813,6 @@ def main():
813 ) 813 )
814 global_progress_bar.set_description("Total progress") 814 global_progress_bar.set_description("Total progress")
815 815
816 def get_loss(noise_pred, noise, latents, timesteps):
817 if noise_scheduler.config.prediction_type == "v_prediction":
818 timesteps = timesteps.view(-1, 1, 1, 1)
819 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
820 alpha_t = torch.sqrt(alphas_cumprod)
821 sigma_t = torch.sqrt(1 - alphas_cumprod)
822 target = alpha_t * noise - sigma_t * latents
823 else:
824 target = noise
825
826 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
827
828 try: 816 try:
829 for epoch in range(num_epochs): 817 for epoch in range(num_epochs):
830 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 818 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -858,24 +846,31 @@ def main():
858 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) 846 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
859 847
860 # Predict the noise residual 848 # Predict the noise residual
861 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 849 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
850
851 # Get the target for loss depending on the prediction type
852 if noise_scheduler.config.prediction_type == "epsilon":
853 target = noise
854 elif noise_scheduler.config.prediction_type == "v_prediction":
855 target = noise_scheduler.get_velocity(latents, noise, timesteps)
856 else:
857 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
862 858
863 if args.num_class_images != 0: 859 if args.num_class_images != 0:
864 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 860 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
865 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) 861 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
866 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 862 target, target_prior = torch.chunk(target, 2, dim=0)
867 noise, noise_prior = torch.chunk(noise, 2, dim=0)
868 863
869 # Compute instance loss 864 # Compute instance loss
870 loss = get_loss(noise_pred, noise, latents, timesteps) 865 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
871 866
872 # Compute prior loss 867 # Compute prior loss
873 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) 868 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
874 869
875 # Add the prior loss to the instance loss. 870 # Add the prior loss to the instance loss.
876 loss = loss + args.prior_loss_weight * prior_loss 871 loss = loss + args.prior_loss_weight * prior_loss
877 else: 872 else:
878 loss = get_loss(noise_pred, noise, latents, timesteps) 873 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
879 874
880 accelerator.backward(loss) 875 accelerator.backward(loss)
881 876
@@ -952,11 +947,19 @@ def main():
952 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 947 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
953 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) 948 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
954 949
955 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 950 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
956 951
957 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 952 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
953
954 # Get the target for loss depending on the prediction type
955 if noise_scheduler.config.prediction_type == "epsilon":
956 target = noise
957 elif noise_scheduler.config.prediction_type == "v_prediction":
958 target = noise_scheduler.get_velocity(latents, noise, timesteps)
959 else:
960 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
958 961
959 loss = get_loss(noise_pred, noise, latents, timesteps) 962 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
960 963
961 loss = loss.detach().item() 964 loss = loss.detach().item()
962 val_loss += loss 965 val_loss += loss