-rw-r--r--  data/csv.py   |   2 +-
-rw-r--r--  dreambooth.py | 161 ++++++++++++++++++++++++-----------------------
2 files changed, 87 insertions(+), 76 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 4c91ded..df15c5a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -76,7 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
+        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
         num_images = len(metadata)
 
         valid_set_size = int(num_images * 0.2)
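
A note on the csv.py change, since the old expression looks plausible: DataFrame.itertuples() yields namedtuples, and "in" on a namedtuple tests membership among the row's values, not its field names, so '"skip" not in item' never detected a missing skip column. hasattr(item, "skip") checks for the field itself. A minimal sketch of the difference (the metadata below is hypothetical, not the repository's data):

# Minimal sketch of the namedtuple pitfall; hypothetical metadata,
# not the repository's data.
import pandas as pd

metadata = pd.DataFrame([
    {"image": "a.png", "skip": True},
    {"image": "b.png", "skip": False},
])

row = next(metadata.itertuples())
print("skip" in row)          # False: tests the values (0, "a.png", True)
print(hasattr(row, "skip"))   # True: "skip" is a field of the row tuple

# The corrected filter keeps rows that lack a skip field or have a falsy one:
kept = [item for item in metadata.itertuples()
        if not hasattr(item, "skip") or item.skip != True]
print(len(kept))              # 1: only b.png survives
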
diff --git a/dreambooth.py b/dreambooth.py
index 1539e81..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,75 +810,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def run_step(batch, train=False, class_images=False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-        # Predict the noise residual
-        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        if class_images:
-            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-            noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-        if train:
-            accelerator.backward(loss)
-
-            if args.initializer_token is not None:
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
-
-                # Get the index for tokens that we want to freeze
-                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
-            if accelerator.sync_gradients:
-                params_to_clip = (
-                    unet.parameters()
-                    if args.initializer_token is not None
-                    else itertools.chain(unet.parameters(), text_encoder.parameters())
-                )
-                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-            optimizer.step()
-            if not accelerator.optimizer_step_was_skipped:
-                lr_scheduler.step()
-            optimizer.zero_grad(set_to_none=True)
-
-        loss = loss.detach().item()
-        return loss
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -892,11 +823,72 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    loss = run_step(
-                        batch,
-                        train=True,
-                        class_images=args.num_class_images != 0
-                    )
+                    # Convert images to latent space
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    if args.num_class_images != 0:
+                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                        # Compute instance loss
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    accelerator.backward(loss)
+
+                    if args.initializer_token is not None:
+                        # Keep the token embeddings fixed except the newly added
+                        # embeddings for the concept, as we only want to optimize the concept embeddings
+                        if accelerator.num_processes > 1:
+                            token_embeds = text_encoder.module.get_input_embeddings().weight
+                        else:
+                            token_embeds = text_encoder.get_input_embeddings().weight
+
+                        # Get the index for tokens that we want to freeze
+                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            unet.parameters()
+                            if args.initializer_token is not None
+                            else itertools.chain(unet.parameters(), text_encoder.parameters())
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    loss = loss.detach().item()
                     train_loss += loss
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
@@ -937,7 +929,26 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss = run_step(batch)
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    loss = loss.detach().item()
                     val_loss += loss
 
             if accelerator.sync_gradients:
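
The validation hunk gathers noise_pred and noise across processes before computing the MSE. A minimal sketch of that pattern with accelerate's gather_for_metrics, assuming a standard Accelerator setup; validation_mse is a hypothetical helper, not code from this repository:

# Minimal sketch of a distributed validation metric; assumes the
# `accelerate` library, everything else here is a placeholder.
import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()

def validation_mse(noise_pred: torch.Tensor, noise: torch.Tensor) -> float:
    # Each process only sees its shard of the batch; gather_for_metrics
    # collects the tensors from all processes (and drops samples that were
    # duplicated to pad the last batch) so the loss is computed over the
    # full validation set rather than one process's slice.
    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
    return F.mse_loss(noise_pred.float(), noise.float(), reduction="mean").item()
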