diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-27 11:02:49 +0100 |
| commit | 9d6c75262b6919758e781b8333428861a5bf7ede (patch) | |
| tree | 72e5814413c18d476813867d87c8360c14aee200 /train_dreambooth.py | |
| parent | Set default dimensions to 768; add config inheritance (diff) | |
| download | textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.gz textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.tar.bz2 textual-inversion-diff-9d6c75262b6919758e781b8333428861a5bf7ede.zip | |
Added learning rate finder
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 129 |
1 file changed, 54 insertions, 75 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 08bc9e0..a62cec9 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -843,6 +843,58 @@ def main(): | |||
| 843 | ) | 843 | ) |
| 844 | global_progress_bar.set_description("Total progress") | 844 | global_progress_bar.set_description("Total progress") |
| 845 | 845 | ||
| 846 | def loop(batch): | ||
| 847 | # Convert images to latent space | ||
| 848 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | ||
| 849 | latents = latents * 0.18215 | ||
| 850 | |||
| 851 | # Sample noise that we'll add to the latents | ||
| 852 | noise = torch.randn_like(latents) | ||
| 853 | bsz = latents.shape[0] | ||
| 854 | # Sample a random timestep for each image | ||
| 855 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
| 856 | (bsz,), device=latents.device) | ||
| 857 | timesteps = timesteps.long() | ||
| 858 | |||
| 859 | # Add noise to the latents according to the noise magnitude at each timestep | ||
| 860 | # (this is the forward diffusion process) | ||
| 861 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 862 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
| 863 | |||
| 864 | # Get the text embedding for conditioning | ||
| 865 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
| 866 | |||
| 867 | # Predict the noise residual | ||
| 868 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 869 | |||
| 870 | # Get the target for loss depending on the prediction type | ||
| 871 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 872 | target = noise | ||
| 873 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 874 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 875 | else: | ||
| 876 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 877 | |||
| 878 | if args.num_class_images != 0: | ||
| 879 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
| 880 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
| 881 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
| 882 | |||
| 883 | # Compute instance loss | ||
| 884 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
| 885 | |||
| 886 | # Compute prior loss | ||
| 887 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
| 888 | |||
| 889 | # Add the prior loss to the instance loss. | ||
| 890 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 891 | else: | ||
| 892 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 893 | |||
| 894 | acc = (model_pred == latents).float().mean() | ||
| 895 | |||
| 896 | return loss, acc, bsz | ||
| 897 | |||
| 846 | try: | 898 | try: |
| 847 | for epoch in range(num_epochs): | 899 | for epoch in range(num_epochs): |
| 848 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 900 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| @@ -859,54 +911,7 @@ def main(): | |||
| 859 | 911 | ||
| 860 | for step, batch in enumerate(train_dataloader): | 912 | for step, batch in enumerate(train_dataloader): |
| 861 | with accelerator.accumulate(unet): | 913 | with accelerator.accumulate(unet): |
| 862 | # Convert images to latent space | 914 | loss, acc, bsz = loop(batch) |
| 863 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | ||
| 864 | latents = latents * 0.18215 | ||
| 865 | |||
| 866 | # Sample noise that we'll add to the latents | ||
| 867 | noise = torch.randn_like(latents) | ||
| 868 | bsz = latents.shape[0] | ||
| 869 | # Sample a random timestep for each image | ||
| 870 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
| 871 | (bsz,), device=latents.device) | ||
| 872 | timesteps = timesteps.long() | ||
| 873 | |||
| 874 | # Add noise to the latents according to the noise magnitude at each timestep | ||
| 875 | # (this is the forward diffusion process) | ||
| 876 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 877 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
| 878 | |||
| 879 | # Get the text embedding for conditioning | ||
| 880 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
| 881 | |||
| 882 | # Predict the noise residual | ||
| 883 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 884 | |||
| 885 | # Get the target for loss depending on the prediction type | ||
| 886 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 887 | target = noise | ||
| 888 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 889 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 890 | else: | ||
| 891 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 892 | |||
| 893 | if args.num_class_images != 0: | ||
| 894 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
| 895 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
| 896 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
| 897 | |||
| 898 | # Compute instance loss | ||
| 899 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | ||
| 900 | |||
| 901 | # Compute prior loss | ||
| 902 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
| 903 | |||
| 904 | # Add the prior loss to the instance loss. | ||
| 905 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 906 | else: | ||
| 907 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 908 | |||
| 909 | acc = (model_pred == latents).float().mean() | ||
| 910 | 915 | ||
| 911 | accelerator.backward(loss) | 916 | accelerator.backward(loss) |
| 912 | 917 | ||
| @@ -960,33 +965,7 @@ def main(): | |||
| 960 | 965 | ||
| 961 | with torch.inference_mode(): | 966 | with torch.inference_mode(): |
| 962 | for step, batch in enumerate(val_dataloader): | 967 | for step, batch in enumerate(val_dataloader): |
| 963 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 968 | loss, acc, bsz = loop(batch) |
| 964 | latents = latents * 0.18215 | ||
| 965 | |||
| 966 | noise = torch.randn_like(latents) | ||
| 967 | bsz = latents.shape[0] | ||
| 968 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | ||
| 969 | (bsz,), device=latents.device) | ||
| 970 | timesteps = timesteps.long() | ||
| 971 | |||
| 972 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 973 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
| 974 | |||
| 975 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
| 976 | |||
| 977 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 978 | |||
| 979 | # Get the target for loss depending on the prediction type | ||
| 980 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 981 | target = noise | ||
| 982 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 983 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 984 | else: | ||
| 985 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 986 | |||
| 987 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 988 | |||
| 989 | acc = (model_pred == latents).float().mean() | ||
| 990 | 969 | ||
| 991 | avg_loss_val.update(loss.detach_(), bsz) | 970 | avg_loss_val.update(loss.detach_(), bsz) |
| 992 | avg_acc_val.update(acc.detach_(), bsz) | 971 | avg_acc_val.update(acc.detach_(), bsz) |
