diff options
author | Volpeon <git@volpeon.ink> | 2022-11-28 20:27:56 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-28 20:27:56 +0100 |
commit | 1386c7badd2930f8a8f8f649216a25f3809a4d96 (patch) | |
tree | 684b487151be99b8dde8848a2886c0aae3a8d017 /textual_inversion.py | |
parent | Fixed noise calculation for v-prediction (diff) | |
download | textual-inversion-diff-1386c7badd2930f8a8f8f649216a25f3809a4d96.tar.gz textual-inversion-diff-1386c7badd2930f8a8f8f649216a25f3809a4d96.tar.bz2 textual-inversion-diff-1386c7badd2930f8a8f8f649216a25f3809a4d96.zip |
Adjusted training to upstream
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 49 |
1 file changed, 26 insertions, 23 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index fa7ae42..7ac9638 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -813,18 +813,6 @@ def main(): | |||
813 | ) | 813 | ) |
814 | global_progress_bar.set_description("Total progress") | 814 | global_progress_bar.set_description("Total progress") |
815 | 815 | ||
816 | def get_loss(noise_pred, noise, latents, timesteps): | ||
817 | if noise_scheduler.config.prediction_type == "v_prediction": | ||
818 | timesteps = timesteps.view(-1, 1, 1, 1) | ||
819 | alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] | ||
820 | alpha_t = torch.sqrt(alphas_cumprod) | ||
821 | sigma_t = torch.sqrt(1 - alphas_cumprod) | ||
822 | target = alpha_t * noise - sigma_t * latents | ||
823 | else: | ||
824 | target = noise | ||
825 | |||
826 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") | ||
827 | |||
828 | try: | 816 | try: |
829 | for epoch in range(num_epochs): | 817 | for epoch in range(num_epochs): |
830 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 818 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
@@ -858,24 +846,31 @@ def main(): | |||
858 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | 846 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) |
859 | 847 | ||
860 | # Predict the noise residual | 848 | # Predict the noise residual |
861 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 849 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
850 | |||
851 | # Get the target for loss depending on the prediction type | ||
852 | if noise_scheduler.config.prediction_type == "epsilon": | ||
853 | target = noise | ||
854 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
855 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
856 | else: | ||
857 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
862 | 858 | ||
863 | if args.num_class_images != 0: | 859 | if args.num_class_images != 0: |
864 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 860 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
865 | latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) | 861 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
866 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 862 | target, target_prior = torch.chunk(target, 2, dim=0) |
867 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
868 | 863 | ||
869 | # Compute instance loss | 864 | # Compute instance loss |
870 | loss = get_loss(noise_pred, noise, latents, timesteps) | 865 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() |
871 | 866 | ||
872 | # Compute prior loss | 867 | # Compute prior loss |
873 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) | 868 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |
874 | 869 | ||
875 | # Add the prior loss to the instance loss. | 870 | # Add the prior loss to the instance loss. |
876 | loss = loss + args.prior_loss_weight * prior_loss | 871 | loss = loss + args.prior_loss_weight * prior_loss |
877 | else: | 872 | else: |
878 | loss = get_loss(noise_pred, noise, latents, timesteps) | 873 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
879 | 874 | ||
880 | accelerator.backward(loss) | 875 | accelerator.backward(loss) |
881 | 876 | ||
@@ -952,11 +947,19 @@ def main(): | |||
952 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 947 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
953 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | 948 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) |
954 | 949 | ||
955 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 950 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
956 | 951 | ||
957 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 952 | model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) |
953 | |||
954 | # Get the target for loss depending on the prediction type | ||
955 | if noise_scheduler.config.prediction_type == "epsilon": | ||
956 | target = noise | ||
957 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
958 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
959 | else: | ||
960 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
958 | 961 | ||
959 | loss = get_loss(noise_pred, noise, latents, timesteps) | 962 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
960 | 963 | ||
961 | loss = loss.detach().item() | 964 | loss = loss.detach().item() |
962 | val_loss += loss | 965 | val_loss += loss |