summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>  2023-03-27 13:19:05 +0200
committer Volpeon <git@volpeon.ink>  2023-03-27 13:19:05 +0200
commit   c40170386fd055f715db90886f0ac0da5c575bd9 (patch)
tree     063d4d89d179750a241e8e652d77ea8586fd2ac7 /train_ti.py
parent   Fix TI (diff)
download textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.tar.gz
         textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.tar.bz2
         textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.zip
Fix TI
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  16
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 9ae8d1b..e4fd464 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -588,14 +588,6 @@ def main():
588 588         unet.enable_gradient_checkpointing()
589 589         text_encoder.gradient_checkpointing_enable()
590 590
591 if args.embeddings_dir is not None:
592 embeddings_dir = Path(args.embeddings_dir)
593 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
594 raise ValueError("--embeddings_dir must point to an existing directory")
595
596 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
597 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
598
599 591     if len(args.alias_tokens) != 0:
600 592         alias_placeholder_tokens = args.alias_tokens[::2]
601 593         alias_initializer_tokens = args.alias_tokens[1::2]
@@ -609,6 +601,14 @@ def main():
609 601         embeddings.persist()
610 602         print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
611 603
604 if args.embeddings_dir is not None:
605 embeddings_dir = Path(args.embeddings_dir)
606 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
607 raise ValueError("--embeddings_dir must point to an existing directory")
608
609 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
610 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
611
612 612     if args.scale_lr:
613 613         args.learning_rate = (
614 614             args.learning_rate * args.gradient_accumulation_steps *