summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py57
-rw-r--r--textual_inversion.py49
2 files changed, 56 insertions, 50 deletions
diff --git a/dreambooth.py b/dreambooth.py
index e9f785c..49d4447 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,18 +893,6 @@ def main():
893 ) 893 )
894 global_progress_bar.set_description("Total progress") 894 global_progress_bar.set_description("Total progress")
895 895
896 def get_loss(noise_pred, noise, latents, timesteps):
897 if noise_scheduler.config.prediction_type == "v_prediction":
898 timesteps = timesteps.view(-1, 1, 1, 1)
899 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
900 alpha_t = torch.sqrt(alphas_cumprod)
901 sigma_t = torch.sqrt(1 - alphas_cumprod)
902 target = alpha_t * noise - sigma_t * latents
903 else:
904 target = noise
905
906 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
907
908 try: 896 try:
909 for epoch in range(num_epochs): 897 for epoch in range(num_epochs):
910 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 898 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -937,24 +925,31 @@ def main():
937 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 925 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
938 926
939 # Predict the noise residual 927 # Predict the noise residual
940 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 928 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
929
930 # Get the target for loss depending on the prediction type
931 if noise_scheduler.config.prediction_type == "epsilon":
932 target = noise
933 elif noise_scheduler.config.prediction_type == "v_prediction":
934 target = noise_scheduler.get_velocity(latents, noise, timesteps)
935 else:
936 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
941 937
942 if args.num_class_images != 0: 938 if args.num_class_images != 0:
943 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 939 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
944 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) 940 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
945 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 941 target, target_prior = torch.chunk(target, 2, dim=0)
946 noise, noise_prior = torch.chunk(noise, 2, dim=0)
947 942
948 # Compute instance loss 943 # Compute instance loss
949 loss = get_loss(noise_pred, noise, latents, timesteps) 944 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
950 945
951 # Compute prior loss 946 # Compute prior loss
952 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) 947 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
953 948
954 # Add the prior loss to the instance loss. 949 # Add the prior loss to the instance loss.
955 loss = loss + args.prior_loss_weight * prior_loss 950 loss = loss + args.prior_loss_weight * prior_loss
956 else: 951 else:
957 loss = get_loss(noise_pred, noise, latents, timesteps) 952 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
958 953
959 accelerator.backward(loss) 954 accelerator.backward(loss)
960 955
@@ -985,7 +980,7 @@ def main():
985 ema_unet.step(unet) 980 ema_unet.step(unet)
986 optimizer.zero_grad(set_to_none=True) 981 optimizer.zero_grad(set_to_none=True)
987 982
988 acc = (noise_pred == latents).float() 983 acc = (model_pred == latents).float()
989 acc = acc.mean() 984 acc = acc.mean()
990 985
991 total_loss += loss.item() 986 total_loss += loss.item()
@@ -1006,8 +1001,8 @@ def main():
1006 sample_checkpoint = True 1001 sample_checkpoint = True
1007 1002
1008 logs = { 1003 logs = {
1009 "train/loss": total_loss / global_step, 1004 "train/loss": total_loss / global_step if global_step != 0 else 0,
1010 "train/acc": total_acc / global_step, 1005 "train/acc": total_acc / global_step if global_step != 0 else 0,
1011 "train/cur_loss": loss.item(), 1006 "train/cur_loss": loss.item(),
1012 "train/cur_acc": acc.item(), 1007 "train/cur_acc": acc.item(),
1013 "lr/unet": lr_scheduler.get_last_lr()[0], 1008 "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1043,13 +1038,21 @@ def main():
1043 1038
1044 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 1039 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
1045 1040
1046 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 1041 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1047 1042
1048 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 1043 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
1044
1045 # Get the target for loss depending on the prediction type
1046 if noise_scheduler.config.prediction_type == "epsilon":
1047 target = noise
1048 elif noise_scheduler.config.prediction_type == "v_prediction":
1049 target = noise_scheduler.get_velocity(latents, noise, timesteps)
1050 else:
1051 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1049 1052
1050 loss = get_loss(noise_pred, noise, latents, timesteps) 1053 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1051 1054
1052 acc = (noise_pred == latents).float() 1055 acc = (model_pred == latents).float()
1053 acc = acc.mean() 1056 acc = acc.mean()
1054 1057
1055 total_loss_val += loss.item() 1058 total_loss_val += loss.item()
diff --git a/textual_inversion.py b/textual_inversion.py
index fa7ae42..7ac9638 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -813,18 +813,6 @@ def main():
813 ) 813 )
814 global_progress_bar.set_description("Total progress") 814 global_progress_bar.set_description("Total progress")
815 815
816 def get_loss(noise_pred, noise, latents, timesteps):
817 if noise_scheduler.config.prediction_type == "v_prediction":
818 timesteps = timesteps.view(-1, 1, 1, 1)
819 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
820 alpha_t = torch.sqrt(alphas_cumprod)
821 sigma_t = torch.sqrt(1 - alphas_cumprod)
822 target = alpha_t * noise - sigma_t * latents
823 else:
824 target = noise
825
826 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
827
828 try: 816 try:
829 for epoch in range(num_epochs): 817 for epoch in range(num_epochs):
830 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 818 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -858,24 +846,31 @@ def main():
858 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) 846 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
859 847
860 # Predict the noise residual 848 # Predict the noise residual
861 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 849 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
850
851 # Get the target for loss depending on the prediction type
852 if noise_scheduler.config.prediction_type == "epsilon":
853 target = noise
854 elif noise_scheduler.config.prediction_type == "v_prediction":
855 target = noise_scheduler.get_velocity(latents, noise, timesteps)
856 else:
857 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
862 858
863 if args.num_class_images != 0: 859 if args.num_class_images != 0:
864 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 860 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
865 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) 861 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
866 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 862 target, target_prior = torch.chunk(target, 2, dim=0)
867 noise, noise_prior = torch.chunk(noise, 2, dim=0)
868 863
869 # Compute instance loss 864 # Compute instance loss
870 loss = get_loss(noise_pred, noise, latents, timesteps) 865 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
871 866
872 # Compute prior loss 867 # Compute prior loss
873 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) 868 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
874 869
875 # Add the prior loss to the instance loss. 870 # Add the prior loss to the instance loss.
876 loss = loss + args.prior_loss_weight * prior_loss 871 loss = loss + args.prior_loss_weight * prior_loss
877 else: 872 else:
878 loss = get_loss(noise_pred, noise, latents, timesteps) 873 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
879 874
880 accelerator.backward(loss) 875 accelerator.backward(loss)
881 876
@@ -952,11 +947,19 @@ def main():
952 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 947 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
953 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) 948 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
954 949
955 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 950 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
956 951
957 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 952 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
953
954 # Get the target for loss depending on the prediction type
955 if noise_scheduler.config.prediction_type == "epsilon":
956 target = noise
957 elif noise_scheduler.config.prediction_type == "v_prediction":
958 target = noise_scheduler.get_velocity(latents, noise, timesteps)
959 else:
960 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
958 961
959 loss = get_loss(noise_pred, noise, latents, timesteps) 962 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
960 963
961 loss = loss.detach().item() 964 loss = loss.detach().item()
962 val_loss += loss 965 val_loss += loss