author | Volpeon <git@volpeon.ink> | 2022-10-22 16:56:10 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-22 16:56:10 +0200 |
commit | 46b1eda6d1c7db552ce5c577bed101c61f09d55b (patch) | |
tree | 80840bbb5a57238b3dabae1d7fa2588a69f79dd9 | |
parent | Training update (diff) | |
download | textual-inversion-diff-46b1eda6d1c7db552ce5c577bed101c61f09d55b.tar.gz textual-inversion-diff-46b1eda6d1c7db552ce5c577bed101c61f09d55b.tar.bz2 textual-inversion-diff-46b1eda6d1c7db552ce5c577bed101c61f09d55b.zip |
Revert lat; fix skip attribute in dataset
-rw-r--r-- | data/csv.py   |   2 |
-rw-r--r-- | dreambooth.py | 153 |
2 files changed, 83 insertions, 72 deletions
diff --git a/data/csv.py b/data/csv.py
index 4c91ded..df15c5a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -76,7 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
+        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
         num_images = len(metadata)
 
         valid_set_size = int(num_images * 0.2)
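The csv.py change matters because a row produced by pandas' `DataFrame.itertuples()` is a namedtuple, and the `in` operator tests membership among the row's values rather than its field names, so the old `"skip" not in item` guard effectively never saw the column. A small standalone sketch (the two-row DataFrame below is made up purely for illustration, it is not part of the repo) shows the difference:

```python
# Standalone illustration of why hasattr() is the right check on rows
# yielded by pandas.DataFrame.itertuples().
import pandas as pd

df = pd.DataFrame([
    {"image": "a.png", "skip": True},
    {"image": "b.png", "skip": False},
])

row = next(df.itertuples())
print("skip" in row)         # False: `in` looks at the row's values, not its field names
print(hasattr(row, "skip"))  # True: attribute lookup does see the "skip" column
print(row.skip != True)      # False, so this row is now correctly filtered out
```

With the `hasattr` form, entries marked `skip: true` in the metadata JSON are filtered out as intended.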
diff --git a/dreambooth.py b/dreambooth.py
index 1539e81..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,93 +810,85 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def run_step(batch, train=False, class_images=False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-        # Predict the noise residual
-        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        if class_images:
-            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-            noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-        if train:
-            accelerator.backward(loss)
-
-            if args.initializer_token is not None:
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
-
-                # Get the index for tokens that we want to freeze
-                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
-            if accelerator.sync_gradients:
-                params_to_clip = (
-                    unet.parameters()
-                    if args.initializer_token is not None
-                    else itertools.chain(unet.parameters(), text_encoder.parameters())
-                )
-                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-            optimizer.step()
-            if not accelerator.optimizer_step_was_skipped:
-                lr_scheduler.step()
-            optimizer.zero_grad(set_to_none=True)
-
-        loss = loss.detach().item()
-        return loss
-
-    try:
-        for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            unet.train()
-            text_encoder.train()
-            train_loss = 0.0
-
-            sample_checkpoint = False
-
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    loss = run_step(
-                        batch,
-                        train=True,
-                        class_images=args.num_class_images != 0
-                    )
+    try:
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            unet.train()
+            text_encoder.train()
+            train_loss = 0.0
+
+            sample_checkpoint = False
+
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+                    # Convert images to latent space
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    if args.num_class_images != 0:
+                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                        # Compute instance loss
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    accelerator.backward(loss)
+
+                    if args.initializer_token is not None:
+                        # Keep the token embeddings fixed except the newly added
+                        # embeddings for the concept, as we only want to optimize the concept embeddings
+                        if accelerator.num_processes > 1:
+                            token_embeds = text_encoder.module.get_input_embeddings().weight
+                        else:
+                            token_embeds = text_encoder.get_input_embeddings().weight
+
+                        # Get the index for tokens that we want to freeze
+                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            unet.parameters()
+                            if args.initializer_token is not None
+                            else itertools.chain(unet.parameters(), text_encoder.parameters())
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    loss = loss.detach().item()
                     train_loss += loss
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
@@ -937,7 +929,26 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss = run_step(batch)
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    loss = loss.detach().item()
                     val_loss += loss
 
            if accelerator.sync_gradients:
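For readers of the dreambooth.py hunk above: when `--num_class_images` is non-zero, the batch is assumed to stack instance examples and class (prior-preservation) examples along the batch dimension, which is why `noise_pred` and `noise` are chunked in half before the two losses are combined. A rough, self-contained sketch with dummy tensors (shapes and the weight value below are invented, not taken from the script):

```python
# Rough sketch of the prior-preservation loss, with dummy tensors standing in
# for the UNet prediction and the target noise.
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0  # hypothetical value; the script reads args.prior_loss_weight

# Pretend batch of 4 latents: first half instance images, second half class images.
noise = torch.randn(4, 4, 64, 64)                    # target noise
noise_pred = noise + 0.1 * torch.randn_like(noise)   # fake model prediction

# Split predictions and targets back into instance / prior halves.
noise_pred_instance, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise_instance, noise_prior = torch.chunk(noise, 2, dim=0)

# Instance loss: per-image MSE averaged over the instance half.
instance_loss = F.mse_loss(noise_pred_instance.float(), noise_instance.float(),
                           reduction="none").mean([1, 2, 3]).mean()

# Prior loss on the class half discourages forgetting the generic class.
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

loss = instance_loss + prior_loss_weight * prior_loss
print(float(loss))
```

Splitting this way lets a single UNet forward pass cover both halves of the batch: the prior term keeps the model close to its original behaviour on the generic class while the instance term fits the new subject.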