author    Volpeon <git@volpeon.ink>  2022-12-01 22:01:47 +0100
committer Volpeon <git@volpeon.ink>  2022-12-01 22:01:47 +0100
commit    7c02c2fe68da2411623f0a11c1187ccf0f7743d8 (patch)
tree      106eddc16374eaa80966782168ab41c6c191145e
parent    Update (diff)
Update
-rw-r--r--  dreambooth.py        | 17
-rw-r--r--  environment.yaml     |  2
-rw-r--r--  infer.py             |  1
-rw-r--r--  textual_inversion.py |  9
4 files changed, 16 insertions(+), 13 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 31dbea2..1ead6dd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -550,11 +550,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
+    # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+    #     raise ValueError(
+    #         "Gradient accumulation is not supported when training the text encoder in distributed training. "
+    #         "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+    #     )
 
     instance_identifier = args.instance_identifier
 
@@ -899,6 +899,9 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
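Note: the two added lines precompute, once before the training loop, the ids of every token whose embedding must stay frozen; the hunk at -967 below then only has to restore those rows after each step. A minimal, self-contained sketch of the same pattern (vocabulary size, embedding width, and placeholder id are toy values, not the script's real ones):

    import torch

    vocab_size = 49408                 # toy stand-in for len(tokenizer)
    placeholder_token_id = [49407]     # toy id(s) for the learned placeholder token(s)

    # Ids of all tokens whose embeddings must not change during training.
    index_fixed_tokens = torch.arange(vocab_size)
    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]

    # Toy embedding table standing in for text_encoder.get_input_embeddings().weight.
    token_embeds = torch.nn.Embedding(vocab_size, 8).weight
    original_token_embeds = token_embeds.detach().clone()

    # After an optimizer step, overwrite every frozen row with its original value,
    # so only the placeholder embedding is actually trained.
    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]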
@@ -910,7 +913,7 @@ def main():
             sample_checkpoint = False
 
             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(unet):
+                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
                     # Convert images to latent space
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
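Note: the accumulate context now tracks both trainable models instead of only the unet. A hedged, runnable sketch of the same idea with toy modules; recent accelerate releases accept several models directly as accumulate(model_a, model_b), whereas the change above chains them with itertools.chain (the names below are placeholders, not the script's real objects):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=2)

    model_a = torch.nn.Linear(8, 8)    # toy stand-in for the unet
    model_b = torch.nn.Linear(8, 1)    # toy stand-in for the text encoder
    optimizer = torch.optim.AdamW(
        list(model_a.parameters()) + list(model_b.parameters()), lr=1e-3)
    model_a, model_b, optimizer = accelerator.prepare(model_a, model_b, optimizer)

    for step in range(4):
        batch = torch.randn(4, 8, device=accelerator.device)
        with accelerator.accumulate(model_a, model_b):
            loss = model_b(model_a(batch)).pow(2).mean()
            accelerator.backward(loss)
            optimizer.step()       # accelerate decides when gradients are actually applied/synced
            optimizer.zero_grad()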
@@ -967,8 +970,6 @@ def main():
                     else:
                         token_embeds = text_encoder.get_input_embeddings().weight
 
-                        # Get the index for tokens that we want to freeze
-                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
                         token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                     if accelerator.sync_gradients:
diff --git a/environment.yaml b/environment.yaml
index 4972ebd..24693d5 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
-  - xformers=0.0.15.dev337
+  - xformers=0.0.15.dev344
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
diff --git a/infer.py b/infer.py
index eabeb5e..75c8621 100644
--- a/infer.py
+++ b/infer.py
@@ -219,6 +219,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
         scheduler=scheduler,
     )
     pipeline.enable_xformers_memory_efficient_attention()
+    pipeline.enable_vae_slicing()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
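For context, enable_vae_slicing() makes the pipeline decode latents one image at a time instead of as a single batch, lowering peak VRAM at a small throughput cost when several images are requested. A hedged sketch against the stock diffusers pipeline (the model id and prompt are placeholders, not this project's configuration):

    import torch
    from diffusers import StableDiffusionPipeline

    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_vae_slicing()   # sliced VAE decode, as added in infer.py above
    pipeline.to("cuda")

    images = pipeline("a photo of an astronaut riding a horse",
                      num_images_per_prompt=4).images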
diff --git a/textual_inversion.py b/textual_inversion.py
index d6be522..80f1d7d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -545,6 +545,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    vae.enable_slicing()
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
@@ -814,6 +815,9 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -827,7 +831,7 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                     latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
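Note: the added .detach() cuts the frozen VAE out of the autograd graph, so textual inversion does not build or store gradients for an encoder it never trains. A minimal sketch with toy modules standing in for the real vae and the trainable path:

    import torch
    from torch import nn

    vae_encoder = nn.Linear(16, 4)   # toy stand-in for the frozen VAE encoder
    trainable = nn.Linear(4, 4)      # toy stand-in for the trainable path

    pixel_values = torch.randn(2, 16)

    latents = vae_encoder(pixel_values).detach()   # no graph back into the frozen VAE
    latents = latents * 0.18215                    # same scaling factor as in the script

    loss = trainable(latents).pow(2).mean()
    loss.backward()

    print(vae_encoder.weight.grad)      # None: the frozen module got no gradients
    print(trainable.weight.grad.shape)  # gradients flow only through the trainable path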
@@ -883,7 +887,6 @@ def main():
                     token_embeds = text_encoder.get_input_embeddings().weight
 
                     # Get the index for tokens that we want to freeze
-                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
                     token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                     optimizer.step()
@@ -927,8 +930,6 @@ def main():
 
             accelerator.wait_for_everyone()
 
-            print(token_embeds[placeholder_token_id])
-
             text_encoder.eval()
             val_loss = 0.0
 