-rw-r--r--  dreambooth.py        | 153
-rw-r--r--  textual_inversion.py |   1
2 files changed, 71 insertions, 83 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 72c56cd..1539e81 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,85 +810,93 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    def run_step(batch, train=False, class_images=False):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+        # Predict the noise residual
+        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        if class_images:
+            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+            noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+        if train:
+            accelerator.backward(loss)
+
+            if args.initializer_token is not None:
+                # Keep the token embeddings fixed except the newly added
+                # embeddings for the concept, as we only want to optimize the concept embeddings
+                if accelerator.num_processes > 1:
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
+                else:
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
+            if accelerator.sync_gradients:
+                params_to_clip = (
+                    unet.parameters()
+                    if args.initializer_token is not None
+                    else itertools.chain(unet.parameters(), text_encoder.parameters())
+                )
+                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+            optimizer.step()
+            if not accelerator.optimizer_step_was_skipped:
+                lr_scheduler.step()
+            optimizer.zero_grad(set_to_none=True)
+
+        loss = loss.detach().item()
+        return loss
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             unet.train()
             text_encoder.train()
             train_loss = 0.0
 
             sample_checkpoint = False
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-                    accelerator.backward(loss)
-
-                    if args.initializer_token is not None:
-                        # Keep the token embeddings fixed except the newly added
-                        # embeddings for the concept, as we only want to optimize the concept embeddings
-                        if accelerator.num_processes > 1:
-                            token_embeds = text_encoder.module.get_input_embeddings().weight
-                        else:
-                            token_embeds = text_encoder.get_input_embeddings().weight
-
-                        # Get the index for tokens that we want to freeze
-                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
-                    if accelerator.sync_gradients:
-                        params_to_clip = (
-                            unet.parameters()
-                            if args.initializer_token is not None
-                            else itertools.chain(unet.parameters(), text_encoder.parameters())
-                        )
-                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                    optimizer.step()
-                    if not accelerator.optimizer_step_was_skipped:
-                        lr_scheduler.step()
-                    optimizer.zero_grad(set_to_none=True)
-
-                    loss = loss.detach().item()
+                    loss = run_step(
+                        batch,
+                        train=True,
+                        class_images=args.num_class_images != 0
+                    )
                     train_loss += loss
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
@@ -929,26 +937,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
-
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-                    loss = loss.detach().item()
+                    loss = run_step(batch)
                     val_loss += loss
 
             if accelerator.sync_gradients:
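
Note: on the validation side, the same step runs with its defaults (train=False, class_images=False), i.e. a forward pass plus plain MSE and nothing else. A minimal sketch of the surrounding loop, using the names from the diff:

    # Sketch of the refactored validation pass (names as in the diff).
    unet.eval()
    text_encoder.eval()
    val_loss = 0.0
    with torch.inference_mode():
        for step, batch in enumerate(val_dataloader):
            # train=False, class_images=False: no backward, no optimizer step
            val_loss += run_step(batch)
    val_loss /= len(val_dataloader)

One trade-off of the shared step: the old loop gathered tensors across processes with accelerator.gather_for_metrics before computing the loss, while run_step computes a per-process loss, so in multi-GPU runs the reported validation loss is local to each process unless it is reduced afterwards.
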
diff --git a/textual_inversion.py b/textual_inversion.py
index fe56d36..c42762f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -520,7 +520,6 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
     # slice_size = unet.config.attention_head_dim // 2
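
Note: with this line removed, textual inversion checkpoints only the text encoder. Gradients still flow through the frozen UNet to the trainable token embeddings, since autograd propagates through modules whose parameters are frozen as long as the input requires grad, so the change affects peak activation memory rather than what gets trained. A tiny self-contained demonstration of that autograd property (illustrative names, not from the script):

    # Gradients pass through a module with frozen parameters.
    import torch
    frozen = torch.nn.Linear(8, 8)
    frozen.requires_grad_(False)                  # analogous to the frozen UNet
    emb = torch.randn(2, 8, requires_grad=True)   # analogous to token embeddings
    frozen(emb).sum().backward()
    print(emb.grad is not None)                   # True: gradient reached the input
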