diff options
author | Volpeon <git@volpeon.ink> | 2022-11-28 13:23:05 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-28 13:23:05 +0100 |
commit | b8ba49fe4c44aaaa30894e5abba22d3bbf94a562 (patch) | |
tree | 0f727ae52048f829852dc0dc21a33cf2dd2f904c | |
parent | Fixed and improved Textual Inversion training (diff) | |
download | textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.gz textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.bz2 textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.zip |
Fixed noise calculation for v-prediction
-rw-r--r-- | dreambooth.py | 21 | ||||
-rw-r--r-- | textual_inversion.py | 29 |
2 files changed, 37 insertions, 13 deletions
diff --git a/dreambooth.py b/dreambooth.py index d15f1ee..e9f785c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -893,6 +893,18 @@ def main(): | |||
893 | ) | 893 | ) |
894 | global_progress_bar.set_description("Total progress") | 894 | global_progress_bar.set_description("Total progress") |
895 | 895 | ||
896 | def get_loss(noise_pred, noise, latents, timesteps): | ||
897 | if noise_scheduler.config.prediction_type == "v_prediction": | ||
898 | timesteps = timesteps.view(-1, 1, 1, 1) | ||
899 | alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] | ||
900 | alpha_t = torch.sqrt(alphas_cumprod) | ||
901 | sigma_t = torch.sqrt(1 - alphas_cumprod) | ||
902 | target = alpha_t * noise - sigma_t * latents | ||
903 | else: | ||
904 | target = noise | ||
905 | |||
906 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") | ||
907 | |||
896 | try: | 908 | try: |
897 | for epoch in range(num_epochs): | 909 | for epoch in range(num_epochs): |
898 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 910 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
@@ -929,19 +941,20 @@ def main(): | |||
929 | 941 | ||
930 | if args.num_class_images != 0: | 942 | if args.num_class_images != 0: |
931 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 943 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
944 | latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) | ||
932 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 945 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
933 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 946 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
934 | 947 | ||
935 | # Compute instance loss | 948 | # Compute instance loss |
936 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() | 949 | loss = get_loss(noise_pred, noise, latents, timesteps) |
937 | 950 | ||
938 | # Compute prior loss | 951 | # Compute prior loss |
939 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") | 952 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) |
940 | 953 | ||
941 | # Add the prior loss to the instance loss. | 954 | # Add the prior loss to the instance loss. |
942 | loss = loss + args.prior_loss_weight * prior_loss | 955 | loss = loss + args.prior_loss_weight * prior_loss |
943 | else: | 956 | else: |
944 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 957 | loss = get_loss(noise_pred, noise, latents, timesteps) |
945 | 958 | ||
946 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
947 | 960 | ||
@@ -1034,7 +1047,7 @@ def main(): | |||
1034 | 1047 | ||
1035 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 1048 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
1036 | 1049 | ||
1037 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 1050 | loss = get_loss(noise_pred, noise, latents, timesteps) |
1038 | 1051 | ||
1039 | acc = (noise_pred == latents).float() | 1052 | acc = (noise_pred == latents).float() |
1040 | acc = acc.mean() | 1053 | acc = acc.mean() |
diff --git a/textual_inversion.py b/textual_inversion.py index 20b1617..fa7ae42 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -439,7 +439,7 @@ class Checkpointer: | |||
439 | with torch.autocast("cuda"), torch.inference_mode(): | 439 | with torch.autocast("cuda"), torch.inference_mode(): |
440 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 440 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
441 | all_samples = [] | 441 | all_samples = [] |
442 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 442 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
443 | file_path.parent.mkdir(parents=True, exist_ok=True) | 443 | file_path.parent.mkdir(parents=True, exist_ok=True) |
444 | 444 | ||
445 | data_enum = enumerate(data) | 445 | data_enum = enumerate(data) |
@@ -568,10 +568,6 @@ def main(): | |||
568 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 568 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
569 | token_embeds = text_encoder.get_input_embeddings().weight.data | 569 | token_embeds = text_encoder.get_input_embeddings().weight.data |
570 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 570 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) |
571 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
572 | |||
573 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | ||
574 | token_embeds[token_id] = embeddings | ||
575 | 571 | ||
576 | if args.resume_checkpoint is not None: | 572 | if args.resume_checkpoint is not None: |
577 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] | 573 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] |
@@ -817,6 +813,18 @@ def main(): | |||
817 | ) | 813 | ) |
818 | global_progress_bar.set_description("Total progress") | 814 | global_progress_bar.set_description("Total progress") |
819 | 815 | ||
816 | def get_loss(noise_pred, noise, latents, timesteps): | ||
817 | if noise_scheduler.config.prediction_type == "v_prediction": | ||
818 | timesteps = timesteps.view(-1, 1, 1, 1) | ||
819 | alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] | ||
820 | alpha_t = torch.sqrt(alphas_cumprod) | ||
821 | sigma_t = torch.sqrt(1 - alphas_cumprod) | ||
822 | target = alpha_t * noise - sigma_t * latents | ||
823 | else: | ||
824 | target = noise | ||
825 | |||
826 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") | ||
827 | |||
820 | try: | 828 | try: |
821 | for epoch in range(num_epochs): | 829 | for epoch in range(num_epochs): |
822 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 830 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
@@ -854,19 +862,20 @@ def main(): | |||
854 | 862 | ||
855 | if args.num_class_images != 0: | 863 | if args.num_class_images != 0: |
856 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 864 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
865 | latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) | ||
857 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 866 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
858 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 867 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
859 | 868 | ||
860 | # Compute instance loss | 869 | # Compute instance loss |
861 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 870 | loss = get_loss(noise_pred, noise, latents, timesteps) |
862 | 871 | ||
863 | # Compute prior loss | 872 | # Compute prior loss |
864 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() | 873 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) |
865 | 874 | ||
866 | # Add the prior loss to the instance loss. | 875 | # Add the prior loss to the instance loss. |
867 | loss = loss + args.prior_loss_weight * prior_loss | 876 | loss = loss + args.prior_loss_weight * prior_loss |
868 | else: | 877 | else: |
869 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 878 | loss = get_loss(noise_pred, noise, latents, timesteps) |
870 | 879 | ||
871 | accelerator.backward(loss) | 880 | accelerator.backward(loss) |
872 | 881 | ||
@@ -922,6 +931,8 @@ def main(): | |||
922 | 931 | ||
923 | accelerator.wait_for_everyone() | 932 | accelerator.wait_for_everyone() |
924 | 933 | ||
934 | print(token_embeds[placeholder_token_id]) | ||
935 | |||
925 | text_encoder.eval() | 936 | text_encoder.eval() |
926 | val_loss = 0.0 | 937 | val_loss = 0.0 |
927 | 938 | ||
@@ -945,7 +956,7 @@ def main(): | |||
945 | 956 | ||
946 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 957 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
947 | 958 | ||
948 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 959 | loss = get_loss(noise_pred, noise, latents, timesteps) |
949 | 960 | ||
950 | loss = loss.detach().item() | 961 | loss = loss.detach().item() |
951 | val_loss += loss | 962 | val_loss += loss |