 dreambooth.py        | 57
 textual_inversion.py | 49
 2 files changed, 56 insertions(+), 50 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index e9f785c..49d4447 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,18 +893,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def get_loss(noise_pred, noise, latents, timesteps):
-        if noise_scheduler.config.prediction_type == "v_prediction":
-            timesteps = timesteps.view(-1, 1, 1, 1)
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
-            alpha_t = torch.sqrt(alphas_cumprod)
-            sigma_t = torch.sqrt(1 - alphas_cumprod)
-            target = alpha_t * noise - sigma_t * latents
-        else:
-            target = noise
-
-        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -937,24 +925,31 @@ def main():
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                     if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
 
                         # Compute instance loss
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                         # Compute prior loss
-                        prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     accelerator.backward(loss)
 
@@ -985,7 +980,7 @@ def main():
                         ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
-                    acc = (noise_pred == latents).float()
+                    acc = (model_pred == latents).float()
                     acc = acc.mean()
 
                     total_loss += loss.item()
@@ -1006,8 +1001,8 @@ def main():
                         sample_checkpoint = True
 
                     logs = {
-                        "train/loss": total_loss / global_step,
-                        "train/acc": total_acc / global_step,
+                        "train/loss": total_loss / global_step if global_step != 0 else 0,
+                        "train/acc": total_acc / global_step if global_step != 0 else 0,
                         "train/cur_loss": loss.item(),
                         "train/cur_acc": acc.item(),
                         "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1043,13 +1038,21 @@ def main():
 
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+                    model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                    loss = get_loss(noise_pred, noise, latents, timesteps)
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                    acc = (noise_pred == latents).float()
+                    acc = (model_pred == latents).float()
                     acc = acc.mean()
 
                     total_loss_val += loss.item()
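
The change above replaces the local get_loss() helper with an inline target selection: for "epsilon" prediction the target stays the raw noise, while for "v_prediction" the scheduler's get_velocity() supplies the target instead of the hand-rolled alpha_t * noise - sigma_t * latents expression. A minimal sketch of that equivalence, written against diffusers' DDPMScheduler (the tensors, shapes, and scheduler configuration below are illustrative and not part of this commit):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")

latents = torch.randn(2, 4, 64, 64)       # clean latents x_0
noise = torch.randn_like(latents)         # sampled noise epsilon
timesteps = torch.randint(0, 1000, (2,))  # one timestep per sample

# Removed helper: target = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents
alphas_cumprod = scheduler.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
manual_target = alphas_cumprod.sqrt() * noise - (1 - alphas_cumprod).sqrt() * latents

# New code path: let the scheduler compute the velocity target
velocity_target = scheduler.get_velocity(latents, noise, timesteps)

print(torch.allclose(manual_target, velocity_target, atol=1e-6))  # expected: True

The same substitution is applied to textual_inversion.py below.
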
diff --git a/textual_inversion.py b/textual_inversion.py
index fa7ae42..7ac9638 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -813,18 +813,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def get_loss(noise_pred, noise, latents, timesteps):
-        if noise_scheduler.config.prediction_type == "v_prediction":
-            timesteps = timesteps.view(-1, 1, 1, 1)
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
-            alpha_t = torch.sqrt(alphas_cumprod)
-            sigma_t = torch.sqrt(1 - alphas_cumprod)
-            target = alpha_t * noise - sigma_t * latents
-        else:
-            target = noise
-
-        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -858,24 +846,31 @@ def main():
                     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                     # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                     if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
 
                         # Compute instance loss
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                         # Compute prior loss
-                        prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     accelerator.backward(loss)
 
@@ -952,11 +947,19 @@ def main():
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
                     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+                    model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                    loss = get_loss(noise_pred, noise, latents, timesteps)
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     loss = loss.detach().item()
                     val_loss += loss
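
For reference, a self-contained sketch of the prior-preservation branch that both scripts now share (args.num_class_images != 0). It assumes instance and class examples are concatenated along the batch dimension before the UNet forward pass; the tensor shapes and the weight value are illustrative, not taken from the commit:

import torch
import torch.nn.functional as F

prior_loss_weight = 1.0

model_pred = torch.randn(4, 4, 64, 64)  # first half: instance samples, second half: class samples
target = torch.randn(4, 4, 64, 64)      # matching targets (noise or velocity)

# Split the batch back into its instance and class halves
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

# Instance loss: per-sample mean over C, H, W, then mean over the batch
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

# Prior-preservation loss: plain mean over the class half
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

# Weighted combination, as in both training scripts
loss = loss + prior_loss_weight * prior_loss

Because both halves are full, equally shaped tensors, the reduction="none" form followed by .mean([1, 2, 3]).mean() evaluates to the same value as a plain reduction="mean"; it merely keeps a per-sample view of the instance loss before averaging.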