diff options
author | Volpeon <git@volpeon.ink> | 2022-11-28 13:23:05 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-28 13:23:05 +0100 |
commit | b8ba49fe4c44aaaa30894e5abba22d3bbf94a562 (patch) | |
tree | 0f727ae52048f829852dc0dc21a33cf2dd2f904c /dreambooth.py | |
parent | Fixed and improved Textual Inversion training (diff) | |
download | textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.gz textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.bz2 textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.zip |
Fixed noise calculation for v-prediction
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 21 |
1 file changed, 17 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index d15f1ee..e9f785c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -893,6 +893,18 @@ def main(): | |||
893 | ) | 893 | ) |
894 | global_progress_bar.set_description("Total progress") | 894 | global_progress_bar.set_description("Total progress") |
895 | 895 | ||
def get_loss(noise_pred, noise, latents, timesteps):
    """Return the scalar MSE training loss for the scheduler's prediction type.

    For a "v_prediction" scheduler the regression target is the velocity
    v = sqrt(alpha_cumprod) * noise - sqrt(1 - alpha_cumprod) * latents;
    for any other prediction type the target is the added noise itself.
    """
    if noise_scheduler.config.prediction_type != "v_prediction":
        target = noise
    else:
        # Per-sample cumulative alphas, reshaped so they broadcast over C, H, W.
        ac = noise_scheduler.alphas_cumprod[timesteps.view(-1, 1, 1, 1)]
        target = torch.sqrt(ac) * noise - torch.sqrt(1 - ac) * latents
    return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
896 | try: | 908 | try: |
897 | for epoch in range(num_epochs): | 909 | for epoch in range(num_epochs): |
898 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 910 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
@@ -929,19 +941,20 @@ def main(): | |||
929 | 941 | ||
930 | if args.num_class_images != 0: | 942 | if args.num_class_images != 0: |
931 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 943 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
944 | latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) | ||
932 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 945 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
933 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 946 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
934 | 947 | ||
935 | # Compute instance loss | 948 | # Compute instance loss |
936 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() | 949 | loss = get_loss(noise_pred, noise, latents, timesteps) |
937 | 950 | ||
938 | # Compute prior loss | 951 | # Compute prior loss |
939 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") | 952 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) |
940 | 953 | ||
941 | # Add the prior loss to the instance loss. | 954 | # Add the prior loss to the instance loss. |
942 | loss = loss + args.prior_loss_weight * prior_loss | 955 | loss = loss + args.prior_loss_weight * prior_loss |
943 | else: | 956 | else: |
944 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 957 | loss = get_loss(noise_pred, noise, latents, timesteps) |
945 | 958 | ||
946 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
947 | 960 | ||
@@ -1034,7 +1047,7 @@ def main(): | |||
1034 | 1047 | ||
1035 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 1048 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
1036 | 1049 | ||
1037 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 1050 | loss = get_loss(noise_pred, noise, latents, timesteps) |
1038 | 1051 | ||
1039 | acc = (noise_pred == latents).float() | 1052 | acc = (noise_pred == latents).float() |
1040 | acc = acc.mean() | 1053 | acc = acc.mean() |