summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-27 11:02:49 +0100
committerVolpeon <git@volpeon.ink>2022-12-27 11:02:49 +0100
commit9d6c75262b6919758e781b8333428861a5bf7ede (patch)
tree72e5814413c18d476813867d87c8360c14aee200 /train_dreambooth.py
parentSet default dimensions to 768; add config inheritance (diff)
downloadtextual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz
textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2
textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip
Added learning rate finder
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py129
1 files changed, 54 insertions, 75 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 08bc9e0..a62cec9 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -843,6 +843,58 @@ def main():
843 ) 843 )
844 global_progress_bar.set_description("Total progress") 844 global_progress_bar.set_description("Total progress")
845 845
846 def loop(batch):
847 # Convert images to latent space
848 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
849 latents = latents * 0.18215
850
851 # Sample noise that we'll add to the latents
852 noise = torch.randn_like(latents)
853 bsz = latents.shape[0]
854 # Sample a random timestep for each image
855 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
856 (bsz,), device=latents.device)
857 timesteps = timesteps.long()
858
859 # Add noise to the latents according to the noise magnitude at each timestep
860 # (this is the forward diffusion process)
861 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
862 noisy_latents = noisy_latents.to(dtype=unet.dtype)
863
864 # Get the text embedding for conditioning
865 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
866
867 # Predict the noise residual
868 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
869
870 # Get the target for loss depending on the prediction type
871 if noise_scheduler.config.prediction_type == "epsilon":
872 target = noise
873 elif noise_scheduler.config.prediction_type == "v_prediction":
874 target = noise_scheduler.get_velocity(latents, noise, timesteps)
875 else:
876 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
877
878 if args.num_class_images != 0:
879 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
880 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
881 target, target_prior = torch.chunk(target, 2, dim=0)
882
883 # Compute instance loss
884 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
885
886 # Compute prior loss
887 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
888
889 # Add the prior loss to the instance loss.
890 loss = loss + args.prior_loss_weight * prior_loss
891 else:
892 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
893
894 acc = (model_pred == latents).float().mean()
895
896 return loss, acc, bsz
897
846 try: 898 try:
847 for epoch in range(num_epochs): 899 for epoch in range(num_epochs):
848 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 900 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -859,54 +911,7 @@ def main():
859 911
860 for step, batch in enumerate(train_dataloader): 912 for step, batch in enumerate(train_dataloader):
861 with accelerator.accumulate(unet): 913 with accelerator.accumulate(unet):
862 # Convert images to latent space 914 loss, acc, bsz = loop(batch)
863 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
864 latents = latents * 0.18215
865
866 # Sample noise that we'll add to the latents
867 noise = torch.randn_like(latents)
868 bsz = latents.shape[0]
869 # Sample a random timestep for each image
870 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
871 (bsz,), device=latents.device)
872 timesteps = timesteps.long()
873
874 # Add noise to the latents according to the noise magnitude at each timestep
875 # (this is the forward diffusion process)
876 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
877 noisy_latents = noisy_latents.to(dtype=unet.dtype)
878
879 # Get the text embedding for conditioning
880 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
881
882 # Predict the noise residual
883 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
884
885 # Get the target for loss depending on the prediction type
886 if noise_scheduler.config.prediction_type == "epsilon":
887 target = noise
888 elif noise_scheduler.config.prediction_type == "v_prediction":
889 target = noise_scheduler.get_velocity(latents, noise, timesteps)
890 else:
891 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
892
893 if args.num_class_images != 0:
894 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
895 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
896 target, target_prior = torch.chunk(target, 2, dim=0)
897
898 # Compute instance loss
899 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
900
901 # Compute prior loss
902 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
903
904 # Add the prior loss to the instance loss.
905 loss = loss + args.prior_loss_weight * prior_loss
906 else:
907 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
908
909 acc = (model_pred == latents).float().mean()
910 915
911 accelerator.backward(loss) 916 accelerator.backward(loss)
912 917
@@ -960,33 +965,7 @@ def main():
960 965
961 with torch.inference_mode(): 966 with torch.inference_mode():
962 for step, batch in enumerate(val_dataloader): 967 for step, batch in enumerate(val_dataloader):
963 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 968 loss, acc, bsz = loop(batch)
964 latents = latents * 0.18215
965
966 noise = torch.randn_like(latents)
967 bsz = latents.shape[0]
968 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
969 (bsz,), device=latents.device)
970 timesteps = timesteps.long()
971
972 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
973 noisy_latents = noisy_latents.to(dtype=unet.dtype)
974
975 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
976
977 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
978
979 # Get the target for loss depending on the prediction type
980 if noise_scheduler.config.prediction_type == "epsilon":
981 target = noise
982 elif noise_scheduler.config.prediction_type == "v_prediction":
983 target = noise_scheduler.get_velocity(latents, noise, timesteps)
984 else:
985 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
986
987 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
988
989 acc = (model_pred == latents).float().mean()
990 969
991 avg_loss_val.update(loss.detach_(), bsz) 970 avg_loss_val.update(loss.detach_(), bsz)
992 avg_acc_val.update(acc.detach_(), bsz) 971 avg_acc_val.update(acc.detach_(), bsz)