summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- data/csv.py 2
-rw-r--r-- dreambooth.py 161
2 files changed, 87 insertions, 76 deletions
diff --git a/data/csv.py b/data/csv.py
index 4c91ded..df15c5a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -76,7 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
76 76
77 def prepare_data(self): 77 def prepare_data(self):
78 metadata = pd.read_json(self.data_file) 78 metadata = pd.read_json(self.data_file)
79 metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True] 79 metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
80 num_images = len(metadata) 80 num_images = len(metadata)
81 81
82 valid_set_size = int(num_images * 0.2) 82 valid_set_size = int(num_images * 0.2)
diff --git a/dreambooth.py b/dreambooth.py
index 1539e81..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,75 +810,6 @@ def main():
810 ) 810 )
811 global_progress_bar.set_description("Total progress") 811 global_progress_bar.set_description("Total progress")
812 812
813 def run_step(batch, train=False, class_images=False):
814 # Convert images to latent space
815 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
816 latents = latents * 0.18215
817
818 # Sample noise that we'll add to the latents
819 noise = torch.randn_like(latents)
820 bsz = latents.shape[0]
821 # Sample a random timestep for each image
822 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
823 timesteps = timesteps.long()
824
825 # Add noise to the latents according to the noise magnitude at each timestep
826 # (this is the forward diffusion process)
827 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
828
829 # Get the text embedding for conditioning
830 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
831
832 # Predict the noise residual
833 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
834
835 if class_images:
836 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
837 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
838 noise, noise_prior = torch.chunk(noise, 2, dim=0)
839
840 # Compute instance loss
841 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
842
843 # Compute prior loss
844 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
845
846 # Add the prior loss to the instance loss.
847 loss = loss + args.prior_loss_weight * prior_loss
848 else:
849 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
850
851 if train:
852 accelerator.backward(loss)
853
854 if args.initializer_token is not None:
855 # Keep the token embeddings fixed except the newly added
856 # embeddings for the concept, as we only want to optimize the concept embeddings
857 if accelerator.num_processes > 1:
858 token_embeds = text_encoder.module.get_input_embeddings().weight
859 else:
860 token_embeds = text_encoder.get_input_embeddings().weight
861
862 # Get the index for tokens that we want to freeze
863 index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
864 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
865
866 if accelerator.sync_gradients:
867 params_to_clip = (
868 unet.parameters()
869 if args.initializer_token is not None
870 else itertools.chain(unet.parameters(), text_encoder.parameters())
871 )
872 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
873
874 optimizer.step()
875 if not accelerator.optimizer_step_was_skipped:
876 lr_scheduler.step()
877 optimizer.zero_grad(set_to_none=True)
878
879 loss = loss.detach().item()
880 return loss
881
882 try: 813 try:
883 for epoch in range(num_epochs): 814 for epoch in range(num_epochs):
884 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 815 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -892,11 +823,72 @@ def main():
892 823
893 for step, batch in enumerate(train_dataloader): 824 for step, batch in enumerate(train_dataloader):
894 with accelerator.accumulate(itertools.chain(unet, text_encoder)): 825 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
895 loss = run_step( 826 # Convert images to latent space
896 batch, 827 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
897 train=True, 828 latents = latents * 0.18215
898 class_images=args.num_class_images != 0 829
899 ) 830 # Sample noise that we'll add to the latents
831 noise = torch.randn_like(latents)
832 bsz = latents.shape[0]
833 # Sample a random timestep for each image
834 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
835 (bsz,), device=latents.device)
836 timesteps = timesteps.long()
837
838 # Add noise to the latents according to the noise magnitude at each timestep
839 # (this is the forward diffusion process)
840 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
841
842 # Get the text embedding for conditioning
843 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
844
845 # Predict the noise residual
846 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
847
848 if args.num_class_images != 0:
849 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
850 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
851 noise, noise_prior = torch.chunk(noise, 2, dim=0)
852
853 # Compute instance loss
854 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
855
856 # Compute prior loss
857 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
858
859 # Add the prior loss to the instance loss.
860 loss = loss + args.prior_loss_weight * prior_loss
861 else:
862 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
863
864 accelerator.backward(loss)
865
866 if args.initializer_token is not None:
867 # Keep the token embeddings fixed except the newly added
868 # embeddings for the concept, as we only want to optimize the concept embeddings
869 if accelerator.num_processes > 1:
870 token_embeds = text_encoder.module.get_input_embeddings().weight
871 else:
872 token_embeds = text_encoder.get_input_embeddings().weight
873
874 # Get the index for tokens that we want to freeze
875 index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
876 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
877
878 if accelerator.sync_gradients:
879 params_to_clip = (
880 unet.parameters()
881 if args.initializer_token is not None
882 else itertools.chain(unet.parameters(), text_encoder.parameters())
883 )
884 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
885
886 optimizer.step()
887 if not accelerator.optimizer_step_was_skipped:
888 lr_scheduler.step()
889 optimizer.zero_grad(set_to_none=True)
890
891 loss = loss.detach().item()
900 train_loss += loss 892 train_loss += loss
901 893
902 # Checks if the accelerator has performed an optimization step behind the scenes 894 # Checks if the accelerator has performed an optimization step behind the scenes
@@ -937,7 +929,26 @@ def main():
937 929
938 with torch.inference_mode(): 930 with torch.inference_mode():
939 for step, batch in enumerate(val_dataloader): 931 for step, batch in enumerate(val_dataloader):
940 loss = run_step(batch) 932 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
933 latents = latents * 0.18215
934
935 noise = torch.randn_like(latents)
936 bsz = latents.shape[0]
937 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
938 (bsz,), device=latents.device)
939 timesteps = timesteps.long()
940
941 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
942
943 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
944
945 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
946
947 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
948
949 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
950
951 loss = loss.detach().item()
941 val_loss += loss 952 val_loss += loss
942 953
943 if accelerator.sync_gradients: 954 if accelerator.sync_gradients: