summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-28 13:23:05 +0100
committerVolpeon <git@volpeon.ink>2022-11-28 13:23:05 +0100
commitb8ba49fe4c44aaaa30894e5abba22d3bbf94a562 (patch)
tree0f727ae52048f829852dc0dc21a33cf2dd2f904c
parentFixed and improved Textual Inversion training (diff)
downloadtextual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.gz
textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.bz2
textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.zip
Fixed noise calculation for v-prediction
-rw-r--r--dreambooth.py21
-rw-r--r--textual_inversion.py29
2 files changed, 37 insertions, 13 deletions
diff --git a/dreambooth.py b/dreambooth.py
index d15f1ee..e9f785c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,6 +893,18 @@ def main():
893 ) 893 )
894 global_progress_bar.set_description("Total progress") 894 global_progress_bar.set_description("Total progress")
895 895
896 def get_loss(noise_pred, noise, latents, timesteps):
897 if noise_scheduler.config.prediction_type == "v_prediction":
898 timesteps = timesteps.view(-1, 1, 1, 1)
899 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
900 alpha_t = torch.sqrt(alphas_cumprod)
901 sigma_t = torch.sqrt(1 - alphas_cumprod)
902 target = alpha_t * noise - sigma_t * latents
903 else:
904 target = noise
905
906 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
907
896 try: 908 try:
897 for epoch in range(num_epochs): 909 for epoch in range(num_epochs):
898 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 910 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -929,19 +941,20 @@ def main():
929 941
930 if args.num_class_images != 0: 942 if args.num_class_images != 0:
931 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 943 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
944 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
932 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 945 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
933 noise, noise_prior = torch.chunk(noise, 2, dim=0) 946 noise, noise_prior = torch.chunk(noise, 2, dim=0)
934 947
935 # Compute instance loss 948 # Compute instance loss
936 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() 949 loss = get_loss(noise_pred, noise, latents, timesteps)
937 950
938 # Compute prior loss 951 # Compute prior loss
939 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") 952 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
940 953
941 # Add the prior loss to the instance loss. 954 # Add the prior loss to the instance loss.
942 loss = loss + args.prior_loss_weight * prior_loss 955 loss = loss + args.prior_loss_weight * prior_loss
943 else: 956 else:
944 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 957 loss = get_loss(noise_pred, noise, latents, timesteps)
945 958
946 accelerator.backward(loss) 959 accelerator.backward(loss)
947 960
@@ -1034,7 +1047,7 @@ def main():
1034 1047
1035 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 1048 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
1036 1049
1037 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 1050 loss = get_loss(noise_pred, noise, latents, timesteps)
1038 1051
1039 acc = (noise_pred == latents).float() 1052 acc = (noise_pred == latents).float()
1040 acc = acc.mean() 1053 acc = acc.mean()
diff --git a/textual_inversion.py b/textual_inversion.py
index 20b1617..fa7ae42 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -439,7 +439,7 @@ class Checkpointer:
439 with torch.autocast("cuda"), torch.inference_mode(): 439 with torch.autocast("cuda"), torch.inference_mode():
440 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 440 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
441 all_samples = [] 441 all_samples = []
442 file_path = samples_path.joinpath(pool, f"step_{step}.png") 442 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
443 file_path.parent.mkdir(parents=True, exist_ok=True) 443 file_path.parent.mkdir(parents=True, exist_ok=True)
444 444
445 data_enum = enumerate(data) 445 data_enum = enumerate(data)
@@ -568,10 +568,6 @@ def main():
568 # Initialise the newly added placeholder token with the embeddings of the initializer token 568 # Initialise the newly added placeholder token with the embeddings of the initializer token
569 token_embeds = text_encoder.get_input_embeddings().weight.data 569 token_embeds = text_encoder.get_input_embeddings().weight.data
570 original_token_embeds = token_embeds.detach().clone().to(accelerator.device) 570 original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
571 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
572
573 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
574 token_embeds[token_id] = embeddings
575 571
576 if args.resume_checkpoint is not None: 572 if args.resume_checkpoint is not None:
577 token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] 573 token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
@@ -817,6 +813,18 @@ def main():
817 ) 813 )
818 global_progress_bar.set_description("Total progress") 814 global_progress_bar.set_description("Total progress")
819 815
816 def get_loss(noise_pred, noise, latents, timesteps):
817 if noise_scheduler.config.prediction_type == "v_prediction":
818 timesteps = timesteps.view(-1, 1, 1, 1)
819 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
820 alpha_t = torch.sqrt(alphas_cumprod)
821 sigma_t = torch.sqrt(1 - alphas_cumprod)
822 target = alpha_t * noise - sigma_t * latents
823 else:
824 target = noise
825
826 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
827
820 try: 828 try:
821 for epoch in range(num_epochs): 829 for epoch in range(num_epochs):
822 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 830 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -854,19 +862,20 @@ def main():
854 862
855 if args.num_class_images != 0: 863 if args.num_class_images != 0:
856 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 864 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
865 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
857 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 866 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
858 noise, noise_prior = torch.chunk(noise, 2, dim=0) 867 noise, noise_prior = torch.chunk(noise, 2, dim=0)
859 868
860 # Compute instance loss 869 # Compute instance loss
861 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 870 loss = get_loss(noise_pred, noise, latents, timesteps)
862 871
863 # Compute prior loss 872 # Compute prior loss
864 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() 873 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
865 874
866 # Add the prior loss to the instance loss. 875 # Add the prior loss to the instance loss.
867 loss = loss + args.prior_loss_weight * prior_loss 876 loss = loss + args.prior_loss_weight * prior_loss
868 else: 877 else:
869 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 878 loss = get_loss(noise_pred, noise, latents, timesteps)
870 879
871 accelerator.backward(loss) 880 accelerator.backward(loss)
872 881
@@ -922,6 +931,8 @@ def main():
922 931
923 accelerator.wait_for_everyone() 932 accelerator.wait_for_everyone()
924 933
934 print(token_embeds[placeholder_token_id])
935
925 text_encoder.eval() 936 text_encoder.eval()
926 val_loss = 0.0 937 val_loss = 0.0
927 938
@@ -945,7 +956,7 @@ def main():
945 956
946 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 957 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
947 958
948 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 959 loss = get_loss(noise_pred, noise, latents, timesteps)
949 960
950 loss = loss.detach().item() 961 loss = loss.detach().item()
951 val_loss += loss 962 val_loss += loss