-rw-r--r-- dreambooth.py        | 17 +++++++++--------
-rw-r--r-- environment.yaml     |  2 +-
-rw-r--r-- infer.py             |  1 +
-rw-r--r-- textual_inversion.py |  9 +++++----
4 files changed, 16 insertions(+), 13 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 31dbea2..1ead6dd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -550,11 +550,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
+    # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+    #     raise ValueError(
+    #         "Gradient accumulation is not supported when training the text encoder in distributed training. "
+    #         "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+    #     )
 
     instance_identifier = args.instance_identifier
 
@@ -899,6 +899,9 @@ def main():
         )
         global_progress_bar.set_description("Total progress")
 
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -910,7 +913,7 @@ def main():
             sample_checkpoint = False
 
             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(unet):
+                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
                     # Convert images to latent space
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
@@ -967,8 +970,6 @@ def main():
                     else:
                         token_embeds = text_encoder.get_input_embeddings().weight
 
-                    # Get the index for tokens that we want to freeze
-                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
                     token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                     if accelerator.sync_gradients:
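The index_fixed_tokens hunks above move the freeze-index computation out of the inner training loop and replace the boolean != mask with torch.isin, which also handles more than one placeholder token id (textual_inversion.py below gets the same change). A minimal self-contained sketch of the pattern; the vocabulary size, embedding width, and placeholder id are made-up values for illustration:

import torch

vocab_size = 49409                 # assumption: CLIP vocab plus one added placeholder token
embed_dim = 768                    # assumption: CLIP text embedding width
placeholder_token_id = [49408]     # id(s) of the newly added token(s)

# Computed once before the loop: every token id except the placeholder(s).
# torch.isin accepts a tensor of ids, so this also covers multiple
# placeholder tokens, which the removed `arange(...) != placeholder_token_id`
# comparison could not.
index_fixed_tokens = torch.arange(vocab_size)
index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]

# Stand-in for the text encoder's input embedding table.
embedding = torch.nn.Embedding(vocab_size, embed_dim)
original_token_embeds = embedding.weight.data.clone()

# After each optimizer.step(), restore the frozen rows so that only the
# placeholder embedding actually trains.
embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]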
diff --git a/environment.yaml b/environment.yaml
index 4972ebd..24693d5 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
-  - xformers=0.0.15.dev337
+  - xformers=0.0.15.dev344
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -219,6 +219,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
         scheduler=scheduler,
     )
     pipeline.enable_xformers_memory_efficient_attention()
+    pipeline.enable_vae_slicing()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
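enable_vae_slicing() makes the pipeline's VAE decode batched latents one image at a time, trading a little throughput for a lower peak-VRAM decode; vae.enable_slicing() in textual_inversion.py below is the same switch applied to the VAE module directly. A minimal usage sketch, assuming a stock diffusers pipeline (the checkpoint id here is an example, not necessarily what this repo loads):

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",   # example checkpoint, substitute your own
    torch_dtype=torch.float16,
)
pipeline.enable_xformers_memory_efficient_attention()
pipeline.enable_vae_slicing()           # decode latents one image at a time
pipeline.to("cuda")

# With slicing enabled, decoding a multi-image batch no longer spikes memory.
images = pipeline("a photo of a sks dog", num_images_per_prompt=4).images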
diff --git a/textual_inversion.py b/textual_inversion.py
index d6be522..80f1d7d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -545,6 +545,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    vae.enable_slicing()
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
@@ -814,6 +815,9 @@ def main():
         )
         global_progress_bar.set_description("Total progress")
 
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -827,7 +831,7 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                     latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
@@ -883,7 +887,6 @@ def main():
                         token_embeds = text_encoder.get_input_embeddings().weight
 
                     # Get the index for tokens that we want to freeze
-                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
                     token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                     optimizer.step()
@@ -927,8 +930,6 @@ def main():
 
             accelerator.wait_for_everyone()
 
-            print(token_embeds[placeholder_token_id])
-
             text_encoder.eval()
             val_loss = 0.0
 
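A note on the .detach() added in the @@ -827,7 +831,7 @@ hunk: textual inversion trains only the text-encoder embeddings, so no gradient ever needs to flow back through vae.encode, and detaching the sampled latents lets PyTorch drop the VAE's activations instead of retaining them for backward. A tiny stand-alone illustration of the effect; the convolution is a stand-in for the VAE encoder, not code from this repo:

import torch

encoder = torch.nn.Conv2d(3, 4, kernel_size=8, stride=8)  # stand-in for vae.encode
pixel_values = torch.randn(2, 3, 64, 64)

latents = encoder(pixel_values).detach() * 0.18215
print(latents.requires_grad)  # False: no encoder activations are kept for backward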