diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 80f1d7d..cd2d22b 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -578,6 +578,9 @@ def main(): | |||
578 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 578 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
579 | token_embeds[token_id] = embeddings | 579 | token_embeds[token_id] = embeddings |
580 | 580 | ||
581 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
582 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
583 | |||
581 | # Freeze vae and unet | 584 | # Freeze vae and unet |
582 | freeze_params(vae.parameters()) | 585 | freeze_params(vae.parameters()) |
583 | freeze_params(unet.parameters()) | 586 | freeze_params(unet.parameters()) |
@@ -815,9 +818,6 @@ def main(): | |||
815 | ) | 818 | ) |
816 | global_progress_bar.set_description("Total progress") | 819 | global_progress_bar.set_description("Total progress") |
817 | 820 | ||
818 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
819 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
820 | |||
821 | try: | 821 | try: |
822 | for epoch in range(num_epochs): | 822 | for epoch in range(num_epochs): |
823 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 823 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |