 dreambooth.py        | 161 ++++++++++++++++++++++-----------------------
 textual_inversion.py |   1 -
 2 files changed, 75 insertions(+), 87 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 72c56cd..1539e81 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,6 +810,75 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    def run_step(batch, train=False, class_images=False):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+        # Predict the noise residual
+        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        if class_images:
+            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+            noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+        if train:
+            accelerator.backward(loss)
+
+            if args.initializer_token is not None:
+                # Keep the token embeddings fixed except the newly added
+                # embeddings for the concept, as we only want to optimize the concept embeddings
+                if accelerator.num_processes > 1:
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
+                else:
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
+            if accelerator.sync_gradients:
+                params_to_clip = (
+                    unet.parameters()
+                    if args.initializer_token is not None
+                    else itertools.chain(unet.parameters(), text_encoder.parameters())
+                )
+                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+            optimizer.step()
+            if not accelerator.optimizer_step_was_skipped:
+                lr_scheduler.step()
+            optimizer.zero_grad(set_to_none=True)
+
+        loss = loss.detach().item()
+        return loss
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
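Note: the class_images branch of run_step assumes the dataloader concatenates
instance samples and class (prior-preservation) samples along the batch
dimension, so torch.chunk(..., 2, dim=0) splits them back apart. A minimal
sketch of a collate function that would produce such a batch -- the example
keys ("instance_image", "class_image", ...) are illustrative assumptions, not
names taken from this repo:

    import torch

    def collate_with_prior(examples):
        # First half of the batch: instance images; second half: class images.
        # Assumes each example dict carries precomputed image and token tensors.
        pixel_values = torch.stack(
            [e["instance_image"] for e in examples]
            + [e["class_image"] for e in examples]
        )
        input_ids = torch.cat(
            [e["instance_ids"] for e in examples]
            + [e["class_ids"] for e in examples]
        )
        return {"pixel_values": pixel_values, "input_ids": input_ids}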
@@ -823,72 +892,11 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-                    accelerator.backward(loss)
-
-                    if args.initializer_token is not None:
-                        # Keep the token embeddings fixed except the newly added
-                        # embeddings for the concept, as we only want to optimize the concept embeddings
-                        if accelerator.num_processes > 1:
-                            token_embeds = text_encoder.module.get_input_embeddings().weight
-                        else:
-                            token_embeds = text_encoder.get_input_embeddings().weight
-
-                        # Get the index for tokens that we want to freeze
-                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
-                    if accelerator.sync_gradients:
-                        params_to_clip = (
-                            unet.parameters()
-                            if args.initializer_token is not None
-                            else itertools.chain(unet.parameters(), text_encoder.parameters())
-                        )
-                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                    optimizer.step()
-                    if not accelerator.optimizer_step_was_skipped:
-                        lr_scheduler.step()
-                    optimizer.zero_grad(set_to_none=True)
-
-                    loss = loss.detach().item()
+                    loss = run_step(
+                        batch,
+                        train=True,
+                        class_images=args.num_class_images != 0
+                    )
                     train_loss += loss
 
             # Checks if the accelerator has performed an optimization step behind the scenes
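For reference, noise_scheduler.add_noise implements the closed-form forward
diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
and the training loss is the MSE between eps and the U-Net's prediction. A
standalone sketch of that noising step, assuming an illustrative linear beta
schedule (the real values come from the scheduler config):

    import torch

    num_train_timesteps = 1000
    betas = torch.linspace(1e-4, 0.02, num_train_timesteps)  # assumed schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def add_noise(latents, noise, timesteps):
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        sqrt_alpha = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
        sqrt_one_minus = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
        return sqrt_alpha * latents + sqrt_one_minus * noise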
@@ -929,26 +937,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
-
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-                    loss = loss.detach().item()
+                    loss = run_step(batch)
                     val_loss += loss
 
             if accelerator.sync_gradients:
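The initializer_token branch of run_step effectively trains a single row of
the token embedding matrix: after each optimizer step, every row except the
placeholder token's is overwritten with its original value. A self-contained
sketch of the same trick (vocabulary size, embedding dim, and token id are
made-up illustrative values):

    import torch

    vocab_size, dim, placeholder_token_id = 49409, 768, 49408  # assumed values
    embedding = torch.nn.Embedding(vocab_size, dim)
    original_token_embeds = embedding.weight.detach().clone()

    # After optimizer.step(): restore every row except the placeholder's,
    # so only the new concept embedding actually moves.
    index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
    embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]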
diff --git a/textual_inversion.py b/textual_inversion.py
index fe56d36..c42762f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -520,7 +520,6 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
     # slice_size = unet.config.attention_head_dim // 2
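The textual_inversion.py change narrows gradient checkpointing to the text
encoder only. A sketch of the per-module toggle pattern involved --
enable_gradient_checkpointing is the diffusers U-Net method and
gradient_checkpointing_enable the transformers one, while the helper itself
is hypothetical:

    def enable_checkpointing(unet, text_encoder, checkpoint_unet=False):
        # Trade recompute time for activation memory, module by module.
        if checkpoint_unet:
            unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()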