diff options
| -rw-r--r-- | dreambooth.py | 21 | ||||
| -rw-r--r-- | textual_inversion.py | 29 |
2 files changed, 37 insertions, 13 deletions
diff --git a/dreambooth.py b/dreambooth.py index d15f1ee..e9f785c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -893,6 +893,18 @@ def main(): | |||
| 893 | ) | 893 | ) |
| 894 | global_progress_bar.set_description("Total progress") | 894 | global_progress_bar.set_description("Total progress") |
| 895 | 895 | ||
| 896 | def get_loss(noise_pred, noise, latents, timesteps): | ||
| 897 | if noise_scheduler.config.prediction_type == "v_prediction": | ||
| 898 | timesteps = timesteps.view(-1, 1, 1, 1) | ||
| 899 | alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] | ||
| 900 | alpha_t = torch.sqrt(alphas_cumprod) | ||
| 901 | sigma_t = torch.sqrt(1 - alphas_cumprod) | ||
| 902 | target = alpha_t * noise - sigma_t * latents | ||
| 903 | else: | ||
| 904 | target = noise | ||
| 905 | |||
| 906 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") | ||
| 907 | |||
| 896 | try: | 908 | try: |
| 897 | for epoch in range(num_epochs): | 909 | for epoch in range(num_epochs): |
| 898 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 910 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| @@ -929,19 +941,20 @@ def main(): | |||
| 929 | 941 | ||
| 930 | if args.num_class_images != 0: | 942 | if args.num_class_images != 0: |
| 931 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 943 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
| 944 | latents, latents_prior = torch.chunk(latents, 2, dim=0) | ||
| 932 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 945 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
| 933 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 946 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
| 934 | 947 | ||
| 935 | # Compute instance loss | 948 | # Compute instance loss |
| 936 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() | 949 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 937 | 950 | ||
| 938 | # Compute prior loss | 951 | # Compute prior loss |
| 939 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") | 952 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) |
| 940 | 953 | ||
| 941 | # Add the prior loss to the instance loss. | 954 | # Add the prior loss to the instance loss. |
| 942 | loss = loss + args.prior_loss_weight * prior_loss | 955 | loss = loss + args.prior_loss_weight * prior_loss |
| 943 | else: | 956 | else: |
| 944 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 957 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 945 | 958 | ||
| 946 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
| 947 | 960 | ||
| @@ -1034,7 +1047,7 @@ def main(): | |||
| 1034 | 1047 | ||
| 1035 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 1048 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 1036 | 1049 | ||
| 1037 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 1050 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 1038 | 1051 | ||
| 1039 | acc = (noise_pred == latents).float() | 1052 | acc = (noise_pred == latents).float() |
| 1040 | acc = acc.mean() | 1053 | acc = acc.mean() |
diff --git a/textual_inversion.py b/textual_inversion.py index 20b1617..fa7ae42 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -439,7 +439,7 @@ class Checkpointer: | |||
| 439 | with torch.autocast("cuda"), torch.inference_mode(): | 439 | with torch.autocast("cuda"), torch.inference_mode(): |
| 440 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 440 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
| 441 | all_samples = [] | 441 | all_samples = [] |
| 442 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 442 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
| 443 | file_path.parent.mkdir(parents=True, exist_ok=True) | 443 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| 444 | 444 | ||
| 445 | data_enum = enumerate(data) | 445 | data_enum = enumerate(data) |
| @@ -568,10 +568,6 @@ def main(): | |||
| 568 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 568 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
| 569 | token_embeds = text_encoder.get_input_embeddings().weight.data | 569 | token_embeds = text_encoder.get_input_embeddings().weight.data |
| 570 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 570 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) |
| 571 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
| 572 | |||
| 573 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | ||
| 574 | token_embeds[token_id] = embeddings | ||
| 575 | 571 | ||
| 576 | if args.resume_checkpoint is not None: | 572 | if args.resume_checkpoint is not None: |
| 577 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] | 573 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] |
| @@ -817,6 +813,18 @@ def main(): | |||
| 817 | ) | 813 | ) |
| 818 | global_progress_bar.set_description("Total progress") | 814 | global_progress_bar.set_description("Total progress") |
| 819 | 815 | ||
| 816 | def get_loss(noise_pred, noise, latents, timesteps): | ||
| 817 | if noise_scheduler.config.prediction_type == "v_prediction": | ||
| 818 | timesteps = timesteps.view(-1, 1, 1, 1) | ||
| 819 | alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] | ||
| 820 | alpha_t = torch.sqrt(alphas_cumprod) | ||
| 821 | sigma_t = torch.sqrt(1 - alphas_cumprod) | ||
| 822 | target = alpha_t * noise - sigma_t * latents | ||
| 823 | else: | ||
| 824 | target = noise | ||
| 825 | |||
| 826 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") | ||
| 827 | |||
| 820 | try: | 828 | try: |
| 821 | for epoch in range(num_epochs): | 829 | for epoch in range(num_epochs): |
| 822 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 830 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| @@ -854,19 +862,20 @@ def main(): | |||
| 854 | 862 | ||
| 855 | if args.num_class_images != 0: | 863 | if args.num_class_images != 0: |
| 856 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 864 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
| 865 | latents, latents_prior = torch.chunk(latents, 2, dim=0) | ||
| 857 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 866 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
| 858 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 867 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
| 859 | 868 | ||
| 860 | # Compute instance loss | 869 | # Compute instance loss |
| 861 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 870 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 862 | 871 | ||
| 863 | # Compute prior loss | 872 | # Compute prior loss |
| 864 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() | 873 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) |
| 865 | 874 | ||
| 866 | # Add the prior loss to the instance loss. | 875 | # Add the prior loss to the instance loss. |
| 867 | loss = loss + args.prior_loss_weight * prior_loss | 876 | loss = loss + args.prior_loss_weight * prior_loss |
| 868 | else: | 877 | else: |
| 869 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 878 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 870 | 879 | ||
| 871 | accelerator.backward(loss) | 880 | accelerator.backward(loss) |
| 872 | 881 | ||
| @@ -922,6 +931,8 @@ def main(): | |||
| 922 | 931 | ||
| 923 | accelerator.wait_for_everyone() | 932 | accelerator.wait_for_everyone() |
| 924 | 933 | ||
| 934 | print(token_embeds[placeholder_token_id]) | ||
| 935 | |||
| 925 | text_encoder.eval() | 936 | text_encoder.eval() |
| 926 | val_loss = 0.0 | 937 | val_loss = 0.0 |
| 927 | 938 | ||
| @@ -945,7 +956,7 @@ def main(): | |||
| 945 | 956 | ||
| 946 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 957 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 947 | 958 | ||
| 948 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 959 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 949 | 960 | ||
| 950 | loss = loss.detach().item() | 961 | loss = loss.detach().item() |
| 951 | val_loss += loss | 962 | val_loss += loss |
