diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py index 9ae8d1b..e4fd464 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -588,14 +588,6 @@ def main(): | |||
588 | unet.enable_gradient_checkpointing() | 588 | unet.enable_gradient_checkpointing() |
589 | text_encoder.gradient_checkpointing_enable() | 589 | text_encoder.gradient_checkpointing_enable() |
590 | 590 | ||
591 | if args.embeddings_dir is not None: | ||
592 | embeddings_dir = Path(args.embeddings_dir) | ||
593 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
594 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
595 | |||
596 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
597 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
598 | |||
599 | if len(args.alias_tokens) != 0: | 591 | if len(args.alias_tokens) != 0: |
600 | alias_placeholder_tokens = args.alias_tokens[::2] | 592 | alias_placeholder_tokens = args.alias_tokens[::2] |
601 | alias_initializer_tokens = args.alias_tokens[1::2] | 593 | alias_initializer_tokens = args.alias_tokens[1::2] |
@@ -609,6 +601,14 @@ def main(): | |||
609 | embeddings.persist() | 601 | embeddings.persist() |
610 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 602 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") |
611 | 603 | ||
604 | if args.embeddings_dir is not None: | ||
605 | embeddings_dir = Path(args.embeddings_dir) | ||
606 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
607 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
608 | |||
609 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
610 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
611 | |||
612 | if args.scale_lr: | 612 | if args.scale_lr: |
613 | args.learning_rate = ( | 613 | args.learning_rate = ( |
614 | args.learning_rate * args.gradient_accumulation_steps * | 614 | args.learning_rate * args.gradient_accumulation_steps * |