summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 80f1d7d..cd2d22b 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -578,6 +578,9 @@ def main():
578 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): 578 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
579 token_embeds[token_id] = embeddings 579 token_embeds[token_id] = embeddings
580 580
581 index_fixed_tokens = torch.arange(len(tokenizer))
582 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
583
581 # Freeze vae and unet 584 # Freeze vae and unet
582 freeze_params(vae.parameters()) 585 freeze_params(vae.parameters())
583 freeze_params(unet.parameters()) 586 freeze_params(unet.parameters())
@@ -815,9 +818,6 @@ def main():
815 ) 818 )
816 global_progress_bar.set_description("Total progress") 819 global_progress_bar.set_description("Total progress")
817 820
818 index_fixed_tokens = torch.arange(len(tokenizer))
819 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
820
821 try: 821 try:
822 for epoch in range(num_epochs): 822 for epoch in range(num_epochs):
823 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 823 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")