summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-01 22:04:10 +0100
committerVolpeon <git@volpeon.ink>2022-12-01 22:04:10 +0100
commit30e072fc795ea36eb92ba25e433b557ff650e35a (patch)
tree9bf142d8c1ec6fb0ed15a0358ed3fa22c7c44f6e /dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-30e072fc795ea36eb92ba25e433b557ff650e35a.tar.gz
textual-inversion-diff-30e072fc795ea36eb92ba25e433b557ff650e35a.tar.bz2
textual-inversion-diff-30e072fc795ea36eb92ba25e433b557ff650e35a.zip
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 1ead6dd..f3f722e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -651,6 +651,9 @@ def main():
651 text_encoder.text_model.embeddings.position_embedding.parameters(), 651 text_encoder.text_model.embeddings.position_embedding.parameters(),
652 )) 652 ))
653 653
654 index_fixed_tokens = torch.arange(len(tokenizer))
655 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
656
654 prompt_processor = PromptProcessor(tokenizer, text_encoder) 657 prompt_processor = PromptProcessor(tokenizer, text_encoder)
655 658
656 if args.scale_lr: 659 if args.scale_lr:
@@ -899,9 +902,6 @@ def main():
899 ) 902 )
900 global_progress_bar.set_description("Total progress") 903 global_progress_bar.set_description("Total progress")
901 904
902 index_fixed_tokens = torch.arange(len(tokenizer))
903 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
904
905 try: 905 try:
906 for epoch in range(num_epochs): 906 for epoch in range(num_epochs):
907 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 907 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")