diff options
author | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
commit | 9d6c75262b6919758e781b8333428861a5bf7ede (patch) | |
tree | 72e5814413c18d476813867d87c8360c14aee200 /train_dreambooth.py | |
parent | Set default dimensions to 768; add config inheritance (diff) | |
download | textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2 textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip |
Added learning rate finder
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 129 |
1 files changed, 54 insertions, 75 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 08bc9e0..a62cec9 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -843,6 +843,58 @@ def main(): | |||
843 | ) | 843 | ) |
844 | global_progress_bar.set_description("Total progress") | 844 | global_progress_bar.set_description("Total progress") |
845 | 845 | ||
846 | def loop(batch): | ||
847 | # Convert images to latent space | ||
848 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | ||
849 | latents = latents * 0.18215 | ||
850 | |||
851 | # Sample noise that we'll add to the latents | ||
852 | noise = torch.randn_like(latents) | ||
853 | bsz = latents.shape[0] | ||
854 | # Sample a random timestep for each image | ||
855 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
856 | (bsz,), device=latents.device) | ||
857 | timesteps = timesteps.long() | ||
858 | |||
859 | # Add noise to the latents according to the noise magnitude at each timestep | ||
860 | # (this is the forward diffusion process) | ||
861 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
862 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
863 | |||
864 | # Get the text embedding for conditioning | ||
865 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
866 | |||
867 | # Predict the noise residual | ||
868 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
869 | |||
870 | # Get the target for loss depending on the prediction type | ||
871 | if noise_scheduler.config.prediction_type == "epsilon": | ||
872 | target = noise | ||
873 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
874 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
875 | else: | ||
876 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
877 | |||
878 | if args.num_class_images != 0: | ||
879 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
880 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
881 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
882 | |||
883 | # Compute instance loss | ||
884 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
885 | |||
886 | # Compute prior loss | ||
887 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
888 | |||
889 | # Add the prior loss to the instance loss. | ||
890 | loss = loss + args.prior_loss_weight * prior_loss | ||
891 | else: | ||
892 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
893 | |||
894 | acc = (model_pred == latents).float().mean() | ||
895 | |||
896 | return loss, acc, bsz | ||
897 | |||
846 | try: | 898 | try: |
847 | for epoch in range(num_epochs): | 899 | for epoch in range(num_epochs): |
848 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 900 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
@@ -859,54 +911,7 @@ def main(): | |||
859 | 911 | ||
860 | for step, batch in enumerate(train_dataloader): | 912 | for step, batch in enumerate(train_dataloader): |
861 | with accelerator.accumulate(unet): | 913 | with accelerator.accumulate(unet): |
862 | # Convert images to latent space | 914 | loss, acc, bsz = loop(batch) |
863 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | ||
864 | latents = latents * 0.18215 | ||
865 | |||
866 | # Sample noise that we'll add to the latents | ||
867 | noise = torch.randn_like(latents) | ||
868 | bsz = latents.shape[0] | ||
869 | # Sample a random timestep for each image | ||
870 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
871 | (bsz,), device=latents.device) | ||
872 | timesteps = timesteps.long() | ||
873 | |||
874 | # Add noise to the latents according to the noise magnitude at each timestep | ||
875 | # (this is the forward diffusion process) | ||
876 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
877 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
878 | |||
879 | # Get the text embedding for conditioning | ||
880 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
881 | |||
882 | # Predict the noise residual | ||
883 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
884 | |||
885 | # Get the target for loss depending on the prediction type | ||
886 | if noise_scheduler.config.prediction_type == "epsilon": | ||
887 | target = noise | ||
888 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
889 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
890 | else: | ||
891 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
892 | |||
893 | if args.num_class_images != 0: | ||
894 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
895 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
896 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
897 | |||
898 | # Compute instance loss | ||
899 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
900 | |||
901 | # Compute prior loss | ||
902 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
903 | |||
904 | # Add the prior loss to the instance loss. | ||
905 | loss = loss + args.prior_loss_weight * prior_loss | ||
906 | else: | ||
907 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
908 | |||
909 | acc = (model_pred == latents).float().mean() | ||
910 | 915 | ||
911 | accelerator.backward(loss) | 916 | accelerator.backward(loss) |
912 | 917 | ||
@@ -960,33 +965,7 @@ def main(): | |||
960 | 965 | ||
961 | with torch.inference_mode(): | 966 | with torch.inference_mode(): |
962 | for step, batch in enumerate(val_dataloader): | 967 | for step, batch in enumerate(val_dataloader): |
963 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 968 | loss, acc, bsz = loop(batch) |
964 | latents = latents * 0.18215 | ||
965 | |||
966 | noise = torch.randn_like(latents) | ||
967 | bsz = latents.shape[0] | ||
968 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
969 | (bsz,), device=latents.device) | ||
970 | timesteps = timesteps.long() | ||
971 | |||
972 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
973 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
974 | |||
975 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
976 | |||
977 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
978 | |||
979 | # Get the target for loss depending on the prediction type | ||
980 | if noise_scheduler.config.prediction_type == "epsilon": | ||
981 | target = noise | ||
982 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
983 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
984 | else: | ||
985 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
986 | |||
987 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
988 | |||
989 | acc = (model_pred == latents).float().mean() | ||
990 | 969 | ||
991 | avg_loss_val.update(loss.detach_(), bsz) | 970 | avg_loss_val.update(loss.detach_(), bsz) |
992 | avg_acc_val.update(acc.detach_(), bsz) | 971 | avg_acc_val.update(acc.detach_(), bsz) |