summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-28 13:23:05 +0100
committerVolpeon <git@volpeon.ink>2022-11-28 13:23:05 +0100
commitb8ba49fe4c44aaaa30894e5abba22d3bbf94a562 (patch)
tree0f727ae52048f829852dc0dc21a33cf2dd2f904c /dreambooth.py
parentFixed and improved Textual Inversion training (diff)
downloadtextual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.gz
textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.tar.bz2
textual-inversion-diff-b8ba49fe4c44aaaa30894e5abba22d3bbf94a562.zip
Fixed noise calculation for v-prediction
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py21
1 file changed, 17 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index d15f1ee..e9f785c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,6 +893,18 @@ def main():
893 ) 893 )
894 global_progress_bar.set_description("Total progress") 894 global_progress_bar.set_description("Total progress")
895 895
896 def get_loss(noise_pred, noise, latents, timesteps):
897 if noise_scheduler.config.prediction_type == "v_prediction":
898 timesteps = timesteps.view(-1, 1, 1, 1)
899 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
900 alpha_t = torch.sqrt(alphas_cumprod)
901 sigma_t = torch.sqrt(1 - alphas_cumprod)
902 target = alpha_t * noise - sigma_t * latents
903 else:
904 target = noise
905
906 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
907
896 try: 908 try:
897 for epoch in range(num_epochs): 909 for epoch in range(num_epochs):
898 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 910 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -929,19 +941,20 @@ def main():
929 941
930 if args.num_class_images != 0: 942 if args.num_class_images != 0:
931 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 943 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
944 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
932 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 945 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
933 noise, noise_prior = torch.chunk(noise, 2, dim=0) 946 noise, noise_prior = torch.chunk(noise, 2, dim=0)
934 947
935 # Compute instance loss 948 # Compute instance loss
936 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() 949 loss = get_loss(noise_pred, noise, latents, timesteps)
937 950
938 # Compute prior loss 951 # Compute prior loss
939 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") 952 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
940 953
941 # Add the prior loss to the instance loss. 954 # Add the prior loss to the instance loss.
942 loss = loss + args.prior_loss_weight * prior_loss 955 loss = loss + args.prior_loss_weight * prior_loss
943 else: 956 else:
944 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 957 loss = get_loss(noise_pred, noise, latents, timesteps)
945 958
946 accelerator.backward(loss) 959 accelerator.backward(loss)
947 960
@@ -1034,7 +1047,7 @@ def main():
1034 1047
1035 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 1048 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
1036 1049
1037 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 1050 loss = get_loss(noise_pred, noise, latents, timesteps)
1038 1051
1039 acc = (noise_pred == latents).float() 1052 acc = (noise_pred == latents).float()
1040 acc = acc.mean() 1053 acc = acc.mean()