diff options
author | Volpeon <git@volpeon.ink> | 2022-12-01 22:04:10 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-01 22:04:10 +0100 |
commit | 30e072fc795ea36eb92ba25e433b557ff650e35a (patch) | |
tree | 9bf142d8c1ec6fb0ed15a0358ed3fa22c7c44f6e /dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-30e072fc795ea36eb92ba25e433b557ff650e35a.tar.gz textual-inversion-diff-30e072fc795ea36eb92ba25e433b557ff650e35a.tar.bz2 textual-inversion-diff-30e072fc795ea36eb92ba25e433b557ff650e35a.zip |
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py index 1ead6dd..f3f722e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -651,6 +651,9 @@ def main(): | |||
651 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 651 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
652 | )) | 652 | )) |
653 | 653 | ||
654 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
655 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
656 | |||
654 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 657 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
655 | 658 | ||
656 | if args.scale_lr: | 659 | if args.scale_lr: |
@@ -899,9 +902,6 @@ def main(): | |||
899 | ) | 902 | ) |
900 | global_progress_bar.set_description("Total progress") | 903 | global_progress_bar.set_description("Total progress") |
901 | 904 | ||
902 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
903 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
904 | |||
905 | try: | 905 | try: |
906 | for epoch in range(num_epochs): | 906 | for epoch in range(num_epochs): |
907 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 907 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |