summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-01 22:01:47 +0100
committerVolpeon <git@volpeon.ink>2022-12-01 22:01:47 +0100
commit7c02c2fe68da2411623f0a11c1187ccf0f7743d8 (patch)
tree106eddc16374eaa80966782168ab41c6c191145e /textual_inversion.py
parentUpdate (diff)
downloadtextual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.tar.gz
textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.tar.bz2
textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.zip
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index d6be522..80f1d7d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -545,6 +545,7 @@ def main():
545 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 545 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
546 args.pretrained_model_name_or_path, subfolder='scheduler') 546 args.pretrained_model_name_or_path, subfolder='scheduler')
547 547
548 vae.enable_slicing()
548 unet.set_use_memory_efficient_attention_xformers(True) 549 unet.set_use_memory_efficient_attention_xformers(True)
549 550
550 if args.gradient_checkpointing: 551 if args.gradient_checkpointing:
@@ -814,6 +815,9 @@ def main():
814 ) 815 )
815 global_progress_bar.set_description("Total progress") 816 global_progress_bar.set_description("Total progress")
816 817
818 index_fixed_tokens = torch.arange(len(tokenizer))
819 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
820
817 try: 821 try:
818 for epoch in range(num_epochs): 822 for epoch in range(num_epochs):
819 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 823 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -827,7 +831,7 @@ def main():
827 for step, batch in enumerate(train_dataloader): 831 for step, batch in enumerate(train_dataloader):
828 with accelerator.accumulate(text_encoder): 832 with accelerator.accumulate(text_encoder):
829 # Convert images to latent space 833 # Convert images to latent space
830 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 834 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
831 latents = latents * 0.18215 835 latents = latents * 0.18215
832 836
833 # Sample noise that we'll add to the latents 837 # Sample noise that we'll add to the latents
@@ -883,7 +887,6 @@ def main():
883 token_embeds = text_encoder.get_input_embeddings().weight 887 token_embeds = text_encoder.get_input_embeddings().weight
884 888
885 # Get the index for tokens that we want to freeze 889 # Get the index for tokens that we want to freeze
886 index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
887 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] 890 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
888 891
889 optimizer.step() 892 optimizer.step()
@@ -927,8 +930,6 @@ def main():
927 930
928 accelerator.wait_for_everyone() 931 accelerator.wait_for_everyone()
929 932
930 print(token_embeds[placeholder_token_id])
931
932 text_encoder.eval() 933 text_encoder.eval()
933 val_loss = 0.0 934 val_loss = 0.0
934 935