path: root/textual_inversion.py
author     Volpeon <git@volpeon.ink>   2022-11-28 13:23:05 +0100
committer  Volpeon <git@volpeon.ink>   2022-11-28 13:23:05 +0100
commit     b8ba49fe4c44aaaa30894e5abba22d3bbf94a562
tree       0f727ae52048f829852dc0dc21a33cf2dd2f904c /textual_inversion.py
parent     Fixed and improved Textual Inversion training
Fixed noise calculation for v-prediction
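
Background for the change below: with v-prediction, the model is trained to predict the velocity v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0 rather than the noise eps itself, so the MSE target has to be rebuilt from the sampled noise and the clean latents. A minimal standalone sketch of that target, assuming a 1-D alphas_cumprod buffer as found on diffusers schedulers (function name is illustrative):

    import torch

    def v_prediction_target(noise, latents, alphas_cumprod, timesteps):
        # v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
        # `latents` are the clean (un-noised) latents, `noise` the sampled epsilon.
        a_t = alphas_cumprod[timesteps].view(-1, 1, 1, 1)  # broadcast over (B, C, H, W)
        return torch.sqrt(a_t) * noise - torch.sqrt(1.0 - a_t) * latents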
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 29 ++++++++++++++++++++---------
1 file changed, 20 insertions(+), 9 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 20b1617..fa7ae42 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -439,7 +439,7 @@ class Checkpointer:
         with torch.autocast("cuda"), torch.inference_mode():
             for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                 all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                 file_path.parent.mkdir(parents=True, exist_ok=True)

                 data_enum = enumerate(data)
@@ -568,10 +568,6 @@ def main():
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-        token_embeds[token_id] = embeddings

     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
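
The block removed above implemented the standard Textual Inversion initialization step: copying the embedding row(s) of the initializer token onto the freshly added placeholder token id(s). For reference, a minimal standalone sketch of that pattern, assuming Hugging Face transformers CLIP classes (model name and token strings are illustrative):

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Register the placeholder token and grow the embedding matrix to make room for it.
    tokenizer.add_tokens(["<my-token>"])
    text_encoder.resize_token_embeddings(len(tokenizer))

    placeholder_id = tokenizer.convert_tokens_to_ids("<my-token>")
    initializer_id = tokenizer.encode("cat", add_special_tokens=False)[0]

    # Copy the initializer embedding row onto the placeholder row.
    with torch.no_grad():
        embeds = text_encoder.get_input_embeddings().weight
        embeds[placeholder_id] = embeds[initializer_id]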
@@ -817,6 +813,18 @@ def main():
     )
     global_progress_bar.set_description("Total progress")

+    def get_loss(noise_pred, noise, latents, timesteps):
+        if noise_scheduler.config.prediction_type == "v_prediction":
+            timesteps = timesteps.view(-1, 1, 1, 1)
+            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
+            alpha_t = torch.sqrt(alphas_cumprod)
+            sigma_t = torch.sqrt(1 - alphas_cumprod)
+            target = alpha_t * noise - sigma_t * latents
+        else:
+            target = noise
+
+        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
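
One property worth noting about the new get_loss: for schedulers configured with the default "epsilon" prediction type it reduces to a plain MSE against the noise, so existing epsilon-prediction runs are unaffected. A hedged standalone check, assuming a diffusers DDPMScheduler (the scheduler class is an assumption; any scheduler exposing config.prediction_type and alphas_cumprod would do):

    import torch
    import torch.nn.functional as F
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler(prediction_type="v_prediction")

    def get_loss(noise_pred, noise, latents, timesteps):
        # Same logic as the hunk above, reproduced standalone.
        if noise_scheduler.config.prediction_type == "v_prediction":
            timesteps = timesteps.view(-1, 1, 1, 1)
            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
            alpha_t = torch.sqrt(alphas_cumprod)
            sigma_t = torch.sqrt(1 - alphas_cumprod)
            target = alpha_t * noise - sigma_t * latents
        else:
            target = noise
        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

    latents = torch.randn(2, 4, 8, 8)       # clean latents x_0
    noise = torch.randn_like(latents)       # sampled epsilon
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (2,))
    noise_pred = torch.randn_like(latents)  # stand-in for the UNet output
    print(get_loss(noise_pred, noise, latents, timesteps))  # scalar loss tensor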
@@ -854,19 +862,20 @@

                 if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)

                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = get_loss(noise_pred, noise, latents, timesteps)

                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)

                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = get_loss(noise_pred, noise, latents, timesteps)

                 accelerator.backward(loss)
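
The branch above follows the usual prior-preservation recipe: the batch is instance samples concatenated with class ("prior") samples, each tensor is split in half along the batch dimension, and the two losses are combined with a weight. A minimal sketch of that combination with epsilon targets (tensor shapes and the weight are illustrative):

    import torch
    import torch.nn.functional as F

    # Assumed batch layout: first half = instance samples, second half = class/prior samples.
    noise_pred = torch.randn(4, 4, 8, 8)   # model output for the combined batch
    noise = torch.randn_like(noise_pred)   # epsilon targets for the combined batch
    prior_loss_weight = 1.0                # mirrors args.prior_loss_weight

    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    instance_loss = F.mse_loss(noise_pred, noise, reduction="mean")
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="mean")
    loss = instance_loss + prior_loss_weight * prior_loss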
@@ -922,6 +931,8 @@

             accelerator.wait_for_everyone()

+            print(token_embeds[placeholder_token_id])
+
             text_encoder.eval()
             val_loss = 0.0

@@ -945,7 +956,7 @@

                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = get_loss(noise_pred, noise, latents, timesteps)

                 loss = loss.detach().item()
                 val_loss += loss