summary | refs | log | tree | commit | diff | stats
path: root/dreambooth.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-11-28 20:27:56 +0100
committer: Volpeon <git@volpeon.ink> 2022-11-28 20:27:56 +0100
commit: 1386c7badd2930f8a8f8f649216a25f3809a4d96 (patch)
tree: 684b487151be99b8dde8848a2886c0aae3a8d017 /dreambooth.py
parent: Fixed noise calculation for v-prediction (diff)
download: textual-inversion-diff-1386c7badd2930f8a8f8f649216a25f3809a4d96.tar.gz
          textual-inversion-diff-1386c7badd2930f8a8f8f649216a25f3809a4d96.tar.bz2
          textual-inversion-diff-1386c7badd2930f8a8f8f649216a25f3809a4d96.zip
Adjusted training to upstream
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 57
1 file changed, 30 insertions, 27 deletions
diff --git a/dreambooth.py b/dreambooth.py
index e9f785c..49d4447 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,18 +893,6 @@ def main():
893 ) 893 )
894 global_progress_bar.set_description("Total progress") 894 global_progress_bar.set_description("Total progress")
895 895
896 def get_loss(noise_pred, noise, latents, timesteps):
897 if noise_scheduler.config.prediction_type == "v_prediction":
898 timesteps = timesteps.view(-1, 1, 1, 1)
899 alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
900 alpha_t = torch.sqrt(alphas_cumprod)
901 sigma_t = torch.sqrt(1 - alphas_cumprod)
902 target = alpha_t * noise - sigma_t * latents
903 else:
904 target = noise
905
906 return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
907
908 try: 896 try:
909 for epoch in range(num_epochs): 897 for epoch in range(num_epochs):
910 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 898 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -937,24 +925,31 @@ def main():
937 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 925 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
938 926
939 # Predict the noise residual 927 # Predict the noise residual
940 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 928 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
929
930 # Get the target for loss depending on the prediction type
931 if noise_scheduler.config.prediction_type == "epsilon":
932 target = noise
933 elif noise_scheduler.config.prediction_type == "v_prediction":
934 target = noise_scheduler.get_velocity(latents, noise, timesteps)
935 else:
936 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
941 937
942 if args.num_class_images != 0: 938 if args.num_class_images != 0:
943 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 939 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
944 latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) 940 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
945 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 941 target, target_prior = torch.chunk(target, 2, dim=0)
946 noise, noise_prior = torch.chunk(noise, 2, dim=0)
947 942
948 # Compute instance loss 943 # Compute instance loss
949 loss = get_loss(noise_pred, noise, latents, timesteps) 944 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
950 945
951 # Compute prior loss 946 # Compute prior loss
952 prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) 947 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
953 948
954 # Add the prior loss to the instance loss. 949 # Add the prior loss to the instance loss.
955 loss = loss + args.prior_loss_weight * prior_loss 950 loss = loss + args.prior_loss_weight * prior_loss
956 else: 951 else:
957 loss = get_loss(noise_pred, noise, latents, timesteps) 952 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
958 953
959 accelerator.backward(loss) 954 accelerator.backward(loss)
960 955
@@ -985,7 +980,7 @@ def main():
985 ema_unet.step(unet) 980 ema_unet.step(unet)
986 optimizer.zero_grad(set_to_none=True) 981 optimizer.zero_grad(set_to_none=True)
987 982
988 acc = (noise_pred == latents).float() 983 acc = (model_pred == latents).float()
989 acc = acc.mean() 984 acc = acc.mean()
990 985
991 total_loss += loss.item() 986 total_loss += loss.item()
@@ -1006,8 +1001,8 @@ def main():
1006 sample_checkpoint = True 1001 sample_checkpoint = True
1007 1002
1008 logs = { 1003 logs = {
1009 "train/loss": total_loss / global_step, 1004 "train/loss": total_loss / global_step if global_step != 0 else 0,
1010 "train/acc": total_acc / global_step, 1005 "train/acc": total_acc / global_step if global_step != 0 else 0,
1011 "train/cur_loss": loss.item(), 1006 "train/cur_loss": loss.item(),
1012 "train/cur_acc": acc.item(), 1007 "train/cur_acc": acc.item(),
1013 "lr/unet": lr_scheduler.get_last_lr()[0], 1008 "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1043,13 +1038,21 @@ def main():
1043 1038
1044 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 1039 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
1045 1040
1046 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 1041 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1047 1042
1048 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 1043 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
1044
1045 # Get the target for loss depending on the prediction type
1046 if noise_scheduler.config.prediction_type == "epsilon":
1047 target = noise
1048 elif noise_scheduler.config.prediction_type == "v_prediction":
1049 target = noise_scheduler.get_velocity(latents, noise, timesteps)
1050 else:
1051 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1049 1052
1050 loss = get_loss(noise_pred, noise, latents, timesteps) 1053 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1051 1054
1052 acc = (noise_pred == latents).float() 1055 acc = (model_pred == latents).float()
1053 acc = acc.mean() 1056 acc = acc.mean()
1054 1057
1055 total_loss_val += loss.item() 1058 total_loss_val += loss.item()