summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-22 11:05:12 +0200
committerVolpeon <git@volpeon.ink>2022-10-22 11:05:12 +0200
commit0995b6d9b16c3b0ac4971e0d2ef4cf8f3ee050e8 (patch)
treece50fd284a73d9f62041fc997c0db4d7ea5f04be /dreambooth.py
parentAdd optional TI functionality to Dreambooth (diff)
downloadtextual-inversion-diff-0995b6d9b16c3b0ac4971e0d2ef4cf8f3ee050e8.tar.gz
textual-inversion-diff-0995b6d9b16c3b0ac4971e0d2ef4cf8f3ee050e8.tar.bz2
textual-inversion-diff-0995b6d9b16c3b0ac4971e0d2ef4cf8f3ee050e8.zip
Training update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py161
1 files changed, 75 insertions, 86 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 72c56cd..1539e81 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,6 +810,75 @@ def main():
810 ) 810 )
811 global_progress_bar.set_description("Total progress") 811 global_progress_bar.set_description("Total progress")
812 812
813 def run_step(batch, train=False, class_images=False):
814 # Convert images to latent space
815 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
816 latents = latents * 0.18215
817
818 # Sample noise that we'll add to the latents
819 noise = torch.randn_like(latents)
820 bsz = latents.shape[0]
821 # Sample a random timestep for each image
822 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
823 timesteps = timesteps.long()
824
825 # Add noise to the latents according to the noise magnitude at each timestep
826 # (this is the forward diffusion process)
827 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
828
829 # Get the text embedding for conditioning
830 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
831
832 # Predict the noise residual
833 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
834
835 if class_images:
836 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
837 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
838 noise, noise_prior = torch.chunk(noise, 2, dim=0)
839
840 # Compute instance loss
841 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
842
843 # Compute prior loss
844 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
845
846 # Add the prior loss to the instance loss.
847 loss = loss + args.prior_loss_weight * prior_loss
848 else:
849 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
850
851 if train:
852 accelerator.backward(loss)
853
854 if args.initializer_token is not None:
855 # Keep the token embeddings fixed except the newly added
856 # embeddings for the concept, as we only want to optimize the concept embeddings
857 if accelerator.num_processes > 1:
858 token_embeds = text_encoder.module.get_input_embeddings().weight
859 else:
860 token_embeds = text_encoder.get_input_embeddings().weight
861
862 # Get the index for tokens that we want to freeze
863 index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
864 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
865
866 if accelerator.sync_gradients:
867 params_to_clip = (
868 unet.parameters()
869 if args.initializer_token is not None
870 else itertools.chain(unet.parameters(), text_encoder.parameters())
871 )
872 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
873
874 optimizer.step()
875 if not accelerator.optimizer_step_was_skipped:
876 lr_scheduler.step()
877 optimizer.zero_grad(set_to_none=True)
878
879 loss = loss.detach().item()
880 return loss
881
813 try: 882 try:
814 for epoch in range(num_epochs): 883 for epoch in range(num_epochs):
815 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 884 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -823,72 +892,11 @@ def main():
823 892
824 for step, batch in enumerate(train_dataloader): 893 for step, batch in enumerate(train_dataloader):
825 with accelerator.accumulate(itertools.chain(unet, text_encoder)): 894 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
826 # Convert images to latent space 895 loss = run_step(
827 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 896 batch,
828 latents = latents * 0.18215 897 train=True,
829 898 class_images=args.num_class_images != 0
830 # Sample noise that we'll add to the latents 899 )
831 noise = torch.randn_like(latents)
832 bsz = latents.shape[0]
833 # Sample a random timestep for each image
834 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
835 (bsz,), device=latents.device)
836 timesteps = timesteps.long()
837
838 # Add noise to the latents according to the noise magnitude at each timestep
839 # (this is the forward diffusion process)
840 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
841
842 # Get the text embedding for conditioning
843 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
844
845 # Predict the noise residual
846 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
847
848 if args.num_class_images != 0:
849 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
850 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
851 noise, noise_prior = torch.chunk(noise, 2, dim=0)
852
853 # Compute instance loss
854 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
855
856 # Compute prior loss
857 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
858
859 # Add the prior loss to the instance loss.
860 loss = loss + args.prior_loss_weight * prior_loss
861 else:
862 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
863
864 accelerator.backward(loss)
865
866 if args.initializer_token is not None:
867 # Keep the token embeddings fixed except the newly added
868 # embeddings for the concept, as we only want to optimize the concept embeddings
869 if accelerator.num_processes > 1:
870 token_embeds = text_encoder.module.get_input_embeddings().weight
871 else:
872 token_embeds = text_encoder.get_input_embeddings().weight
873
874 # Get the index for tokens that we want to freeze
875 index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
876 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
877
878 if accelerator.sync_gradients:
879 params_to_clip = (
880 unet.parameters()
881 if args.initializer_token is not None
882 else itertools.chain(unet.parameters(), text_encoder.parameters())
883 )
884 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
885
886 optimizer.step()
887 if not accelerator.optimizer_step_was_skipped:
888 lr_scheduler.step()
889 optimizer.zero_grad(set_to_none=True)
890
891 loss = loss.detach().item()
892 train_loss += loss 900 train_loss += loss
893 901
894 # Checks if the accelerator has performed an optimization step behind the scenes 902 # Checks if the accelerator has performed an optimization step behind the scenes
@@ -929,26 +937,7 @@ def main():
929 937
930 with torch.inference_mode(): 938 with torch.inference_mode():
931 for step, batch in enumerate(val_dataloader): 939 for step, batch in enumerate(val_dataloader):
932 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 940 loss = run_step(batch)
933 latents = latents * 0.18215
934
935 noise = torch.randn_like(latents)
936 bsz = latents.shape[0]
937 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
938 (bsz,), device=latents.device)
939 timesteps = timesteps.long()
940
941 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
942
943 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
944
945 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
946
947 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
948
949 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
950
951 loss = loss.detach().item()
952 val_loss += loss 941 val_loss += loss
953 942
954 if accelerator.sync_gradients: 943 if accelerator.sync_gradients: