diff options
| -rw-r--r-- | dreambooth.py | 6 | ||||
| -rw-r--r-- | textual_inversion.py | 6 |
2 files changed, 6 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py index 1ead6dd..f3f722e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -651,6 +651,9 @@ def main(): | |||
| 651 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 651 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
| 652 | )) | 652 | )) |
| 653 | 653 | ||
| 654 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
| 655 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
| 656 | |||
| 654 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 657 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 655 | 658 | ||
| 656 | if args.scale_lr: | 659 | if args.scale_lr: |
| @@ -899,9 +902,6 @@ def main(): | |||
| 899 | ) | 902 | ) |
| 900 | global_progress_bar.set_description("Total progress") | 903 | global_progress_bar.set_description("Total progress") |
| 901 | 904 | ||
| 902 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
| 903 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
| 904 | |||
| 905 | try: | 905 | try: |
| 906 | for epoch in range(num_epochs): | 906 | for epoch in range(num_epochs): |
| 907 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 907 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
diff --git a/textual_inversion.py b/textual_inversion.py index 80f1d7d..cd2d22b 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -578,6 +578,9 @@ def main(): | |||
| 578 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 578 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
| 579 | token_embeds[token_id] = embeddings | 579 | token_embeds[token_id] = embeddings |
| 580 | 580 | ||
| 581 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
| 582 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
| 583 | |||
| 581 | # Freeze vae and unet | 584 | # Freeze vae and unet |
| 582 | freeze_params(vae.parameters()) | 585 | freeze_params(vae.parameters()) |
| 583 | freeze_params(unet.parameters()) | 586 | freeze_params(unet.parameters()) |
| @@ -815,9 +818,6 @@ def main(): | |||
| 815 | ) | 818 | ) |
| 816 | global_progress_bar.set_description("Total progress") | 819 | global_progress_bar.set_description("Total progress") |
| 817 | 820 | ||
| 818 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
| 819 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
| 820 | |||
| 821 | try: | 821 | try: |
| 822 | for epoch in range(num_epochs): | 822 | for epoch in range(num_epochs): |
| 823 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 823 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
