diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index d15f1ee..e9f785c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
| @@ -893,6 +893,18 @@ def main(): | |||
| 893 | ) | 893 | ) |
| 894 | global_progress_bar.set_description("Total progress") | 894 | global_progress_bar.set_description("Total progress") |
| 895 | 895 | ||
| 896 | def get_loss(noise_pred, noise, latents, timesteps): | ||
| 897 | if noise_scheduler.config.prediction_type == "v_prediction": | ||
| 898 | timesteps = timesteps.view(-1, 1, 1, 1) | ||
| 899 | alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] | ||
| 900 | alpha_t = torch.sqrt(alphas_cumprod) | ||
| 901 | sigma_t = torch.sqrt(1 - alphas_cumprod) | ||
| 902 | target = alpha_t * noise - sigma_t * latents | ||
| 903 | else: | ||
| 904 | target = noise | ||
| 905 | |||
| 906 | return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") | ||
| 907 | |||
| 896 | try: | 908 | try: |
| 897 | for epoch in range(num_epochs): | 909 | for epoch in range(num_epochs): |
| 898 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 910 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| @@ -929,19 +941,20 @@ def main(): | |||
| 929 | 941 | ||
| 930 | if args.num_class_images != 0: | 942 | if args.num_class_images != 0: |
| 931 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 943 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
| 944 | latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) | ||
| 932 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 945 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
| 933 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 946 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
| 934 | 947 | ||
| 935 | # Compute instance loss | 948 | # Compute instance loss |
| 936 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() | 949 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 937 | 950 | ||
| 938 | # Compute prior loss | 951 | # Compute prior loss |
| 939 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") | 952 | prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) |
| 940 | 953 | ||
| 941 | # Add the prior loss to the instance loss. | 954 | # Add the prior loss to the instance loss. |
| 942 | loss = loss + args.prior_loss_weight * prior_loss | 955 | loss = loss + args.prior_loss_weight * prior_loss |
| 943 | else: | 956 | else: |
| 944 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 957 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 945 | 958 | ||
| 946 | accelerator.backward(loss) | 959 | accelerator.backward(loss) |
| 947 | 960 | ||
| @@ -1034,7 +1047,7 @@ def main(): | |||
| 1034 | 1047 | ||
| 1035 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 1048 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 1036 | 1049 | ||
| 1037 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 1050 | loss = get_loss(noise_pred, noise, latents, timesteps) |
| 1038 | 1051 | ||
| 1039 | acc = (noise_pred == latents).float() | 1052 | acc = (noise_pred == latents).float() |
| 1040 | acc = acc.mean() | 1053 | acc = acc.mean() |
