author     Volpeon <git@volpeon.ink>  2022-10-22 16:56:10 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-22 16:56:10 +0200
commit     46b1eda6d1c7db552ce5c577bed101c61f09d55b (patch)
tree       80840bbb5a57238b3dabae1d7fa2588a69f79dd9
parent     Training update (diff)
Revert lat; fix skip attribute in dataset
-rw-r--r--  data/csv.py      2
-rw-r--r--  dreambooth.py  153
2 files changed, 83 insertions, 72 deletions
diff --git a/data/csv.py b/data/csv.py
index 4c91ded..df15c5a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -76,7 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
+        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
         num_images = len(metadata)
 
         valid_set_size = int(num_images * 0.2)
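
Note on the one-line fix above: pandas' itertuples() yields namedtuples, so `"skip" in item` tests the row's values rather than its column names, meaning the old filter never actually saw the skip flag; hasattr(item, "skip") checks for the column itself. A minimal illustrative sketch (not code from the repo, column names are made up):

import pandas as pd

# Two rows, one of which should be skipped.
metadata = pd.DataFrame([
    {"image": "a.jpg", "prompt": "a photo", "skip": True},
    {"image": "b.jpg", "prompt": "a photo", "skip": False},
])

row = next(metadata.itertuples())
print("skip" in row)          # False: membership checks the row's *values*, not field names
print(hasattr(row, "skip"))   # True: the column exists as a namedtuple attribute
print(row.skip)               # True: with the fixed condition this row is filtered out
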
diff --git a/dreambooth.py b/dreambooth.py
index 1539e81..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,93 +810,85 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def run_step(batch, train=False, class_images=False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-        timesteps = timesteps.long()
+    try:
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
 
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            unet.train()
+            text_encoder.train()
+            train_loss = 0.0
 
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+            sample_checkpoint = False
 
-        # Predict the noise residual
-        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+                    # Convert images to latent space
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
 
-        if class_images:
-            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-            noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
 
-            # Compute instance loss
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-            # Compute prior loss
-            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-        if train:
-            accelerator.backward(loss)
+                    if args.num_class_images != 0:
+                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
-            if args.initializer_token is not None:
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
+                        # Compute instance loss
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
 
-                # Get the index for tokens that we want to freeze
-                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
 
-            if accelerator.sync_gradients:
-                params_to_clip = (
-                    unet.parameters()
-                    if args.initializer_token is not None
-                    else itertools.chain(unet.parameters(), text_encoder.parameters())
-                )
-                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-            optimizer.step()
-            if not accelerator.optimizer_step_was_skipped:
-                lr_scheduler.step()
-            optimizer.zero_grad(set_to_none=True)
+                    accelerator.backward(loss)
 
-        loss = loss.detach().item()
-        return loss
+                    if args.initializer_token is not None:
+                        # Keep the token embeddings fixed except the newly added
+                        # embeddings for the concept, as we only want to optimize the concept embeddings
+                        if accelerator.num_processes > 1:
+                            token_embeds = text_encoder.module.get_input_embeddings().weight
+                        else:
+                            token_embeds = text_encoder.get_input_embeddings().weight
 
-    try:
-        for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
+                        # Get the index for tokens that we want to freeze
+                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
-            unet.train()
-            text_encoder.train()
-            train_loss = 0.0
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            unet.parameters()
+                            if args.initializer_token is not None
+                            else itertools.chain(unet.parameters(), text_encoder.parameters())
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
-            sample_checkpoint = False
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
 
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    loss = run_step(
-                        batch,
-                        train=True,
-                        class_images=args.num_class_images != 0
-                    )
+                    loss = loss.detach().item()
                     train_loss += loss
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
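
The inlined training step keeps the same prior-preservation logic as run_step did: when class images are used, the batch is assumed to concatenate instance and class images along the batch dimension, so the predictions and targets are chunked in half, the instance half gets a per-sample MSE and the class half a prior loss weighted by args.prior_loss_weight. A minimal, self-contained sketch of that loss with random tensors standing in for the UNet output (prior_loss_weight is a placeholder for the CLI argument):

import torch
import torch.nn.functional as F

prior_loss_weight = 1.0
noise_pred = torch.randn(4, 4, 64, 64)   # instance + class predictions, concatenated on dim 0
noise = torch.randn(4, 4, 64, 64)        # matching noise targets

# First half of the batch: instance images; second half: class (prior) images.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

# Per-sample instance loss, then a scalar prior loss, combined with a weight.
instance_loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
loss = instance_loss + prior_loss_weight * prior_loss
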
@@ -937,7 +929,26 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss = run_step(batch)
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    loss = loss.detach().item()
                     val_loss += loss
 
                     if accelerator.sync_gradients:
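
Unlike the old run_step path, the rewritten validation loop gathers noise_pred and noise across processes with accelerator.gather_for_metrics before reducing to an MSE, so the reported validation loss covers every device's samples rather than only the local shard. A minimal sketch of that gather-then-reduce pattern, with random tensors standing in for the model outputs and an Accelerator constructed locally (both assumptions for illustration):

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()

preds = torch.randn(2, 4, 64, 64, device=accelerator.device)
targets = torch.randn(2, 4, 64, 64, device=accelerator.device)

# Collect every process's tensors (duplicated samples from the last uneven
# batch are dropped) so the metric reflects the full validation set.
preds, targets = accelerator.gather_for_metrics((preds, targets))
val_loss = F.mse_loss(preds.float(), targets.float(), reduction="mean").item()
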