 dreambooth.py        | 57
 textual_inversion.py | 49
 2 files changed, 56 insertions(+), 50 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index e9f785c..49d4447 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -893,18 +893,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def get_loss(noise_pred, noise, latents, timesteps):
-        if noise_scheduler.config.prediction_type == "v_prediction":
-            timesteps = timesteps.view(-1, 1, 1, 1)
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
-            alpha_t = torch.sqrt(alphas_cumprod)
-            sigma_t = torch.sqrt(1 - alphas_cumprod)
-            target = alpha_t * noise - sigma_t * latents
-        else:
-            target = noise
-
-        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -937,24 +925,31 @@ def main():
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                     if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
 
                         # Compute instance loss
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                         # Compute prior loss
-                        prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     accelerator.backward(loss)
 
@@ -985,7 +980,7 @@ def main():
                         ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
-                    acc = (noise_pred == latents).float()
+                    acc = (model_pred == latents).float()
                     acc = acc.mean()
 
                     total_loss += loss.item()
@@ -1006,8 +1001,8 @@ def main():
                         sample_checkpoint = True
 
                 logs = {
-                    "train/loss": total_loss / global_step,
-                    "train/acc": total_acc / global_step,
+                    "train/loss": total_loss / global_step if global_step != 0 else 0,
+                    "train/acc": total_acc / global_step if global_step != 0 else 0,
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1043,13 +1038,21 @@ def main():
 
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
 
-                noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                loss = get_loss(noise_pred, noise, latents, timesteps)
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                acc = (noise_pred == latents).float()
+                acc = (model_pred == latents).float()
                 acc = acc.mean()
 
                 total_loss_val += loss.item()
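Aside (not part of the commit): the deleted get_loss helper built the v-prediction target by hand from alphas_cumprod, while the new code delegates to the scheduler. The two compute the same quantity; a minimal standalone sketch of the equivalence, assuming a diffusers version whose DDPMScheduler exposes get_velocity() and with illustrative tensor shapes:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")

latents = torch.randn(4, 4, 64, 64)       # clean latents x_0
noise = torch.randn_like(latents)         # sampled epsilon
timesteps = torch.randint(0, 1000, (4,))  # one timestep per example

# Removed helper: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
alphas_cumprod = scheduler.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
alpha_t = torch.sqrt(alphas_cumprod)
sigma_t = torch.sqrt(1 - alphas_cumprod)
manual_target = alpha_t * noise - sigma_t * latents

# New code path: the scheduler computes the same target
scheduler_target = scheduler.get_velocity(latents, noise, timesteps)

assert torch.allclose(manual_target, scheduler_target, atol=1e-6)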
diff --git a/textual_inversion.py b/textual_inversion.py
index fa7ae42..7ac9638 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -813,18 +813,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def get_loss(noise_pred, noise, latents, timesteps):
-        if noise_scheduler.config.prediction_type == "v_prediction":
-            timesteps = timesteps.view(-1, 1, 1, 1)
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps]
-            alpha_t = torch.sqrt(alphas_cumprod)
-            sigma_t = torch.sqrt(1 - alphas_cumprod)
-            target = alpha_t * noise - sigma_t * latents
-        else:
-            target = noise
-
-        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -858,24 +846,31 @@ def main():
                     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                     # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                     if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        latents, latents_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
 
                         # Compute instance loss
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                         # Compute prior loss
-                        prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps)
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = get_loss(noise_pred, noise, latents, timesteps)
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     accelerator.backward(loss)
 
@@ -952,11 +947,19 @@ def main():
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
 
-                noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                loss = get_loss(noise_pred, noise, latents, timesteps)
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 loss = loss.detach().item()
                 val_loss += loss
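Aside (not part of the commit): the prior-preservation branch in both files relies on the batch stacking instance examples first and class examples second, so both halves share one forward pass. The rewrite also fixes a bug in the old branch, which chunked noise_pred into latents and thereby clobbered the real latents. A standalone sketch of the new loss under that batch-layout assumption; the helper name prior_preservation_loss is hypothetical:

import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred, target, prior_loss_weight=1.0):
    # First half of the batch = instance examples, second half = class (prior) examples.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Instance loss: per-example mean over (C, H, W), then mean over the batch.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

    # Prior-preservation loss: plain mean over the class half.
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

    return loss + prior_loss_weight * prior_loss

For equal-shaped examples the two reductions produce the same scalar; the reduction="none" form just makes the per-example averaging explicit.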
