| author | Volpeon <git@volpeon.ink> | 2022-11-28 20:27:56 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-11-28 20:27:56 +0100 |
| commit | 1386c7badd2930f8a8f8f649216a25f3809a4d96 (patch) | |
| tree | 684b487151be99b8dde8848a2886c0aae3a8d017 /textual_inversion.py | |
| parent | Fixed noise calculation for v-prediction (diff) | |
Adjusted training to upstream
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 49 |
1 file changed, 26 insertions(+), 23 deletions(-)
```diff
diff --git a/textual_inversion.py b/textual_inversion.py
index fa7ae42..7ac9638 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -813,18 +813,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def get_loss(noise_pred, noise, latents, timesteps):
-        if noise_scheduler.config.prediction_type == "v_prediction":
-            timesteps = timesteps.view(-1, 1, 1, 1)
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
-            alpha_t = torch.sqrt(alphas_cumprod)
-            sigma_t = torch.sqrt(1 - alphas_cumprod)
-            target = alpha_t * noise - sigma_t * latents
-        else:
-            target = noise
-
-        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
```
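The removed `get_loss()` helper assembled the v-prediction target by hand from `alphas_cumprod`; the hunks below switch to the scheduler's own `get_velocity()`, matching upstream. As a rough sketch, not part of the patch (the helper name `velocity_target` is ours), both compute the same quantity:

```python
# Rough sketch, not part of the patch: the v-prediction target that the removed
# get_loss() built by hand and that noise_scheduler.get_velocity() replaces.
import torch

def velocity_target(alphas_cumprod, latents, noise, timesteps):
    # alphas_cumprod: 1-D tensor of cumulative alpha products, indexed by timestep
    a = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    # v = sqrt(a_t) * eps - sqrt(1 - a_t) * x_0
    return torch.sqrt(a) * noise - torch.sqrt(1 - a) * latents
```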
```diff
@@ -858,24 +846,31 @@ def main():
                     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                     # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                     if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
 
                         # Compute instance loss
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                         # Compute prior loss
-                        prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     accelerator.backward(loss)
 
```
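With prior preservation enabled (`args.num_class_images != 0`), the batch holds instance examples followed by class examples along dim 0. The old branch chunked the prediction into `latents`/`latents_prior`, clobbering the latents; the new branch chunks the precomputed `target` instead. A minimal sketch of the loss wiring above (the helper name `prior_preservation_loss` is ours, not the patch's):

```python
# Minimal sketch of the loss wiring in the hunk above; assumes model_pred and
# target each contain instance examples followed by class examples along dim 0.
import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred, target, prior_loss_weight):
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    # Instance loss: per-example mean over C/H/W, then mean over the batch
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
    # Prior (class) loss: plain mean over all elements
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    return loss + prior_loss_weight * prior_loss
```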
```diff
@@ -952,11 +947,19 @@ def main():
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
                     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
 
-                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                    loss = get_loss(noise_pred, noise, latents, timesteps)
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     loss = loss.detach().item()
                     val_loss += loss
```
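The training and validation hunks now inline the same prediction-type dispatch. A hedged sketch of that shared branch, assuming a diffusers-style scheduler exposing `config.prediction_type` and `get_velocity()` (the patch itself repeats the branch inline rather than factoring it out; the helper name `get_target` is ours):

```python
# Sketch only: the target selection repeated in both the training and validation hunks.
def get_target(noise_scheduler, latents, noise, timesteps):
    if noise_scheduler.config.prediction_type == "epsilon":
        return noise
    if noise_scheduler.config.prediction_type == "v_prediction":
        return noise_scheduler.get_velocity(latents, noise, timesteps)
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
```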
