diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-27 13:19:05 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-27 13:19:05 +0200 |
| commit | c40170386fd055f715db90886f0ac0da5c575bd9 (patch) | |
| tree | 063d4d89d179750a241e8e652d77ea8586fd2ac7 /train_ti.py | |
| parent | Fix TI (diff) | |
| download | textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.tar.gz textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.tar.bz2 textual-inversion-diff-c40170386fd055f715db90886f0ac0da5c575bd9.zip | |
Fix TI
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py index 9ae8d1b..e4fd464 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -588,14 +588,6 @@ def main(): | |||
| 588 | unet.enable_gradient_checkpointing() | 588 | unet.enable_gradient_checkpointing() |
| 589 | text_encoder.gradient_checkpointing_enable() | 589 | text_encoder.gradient_checkpointing_enable() |
| 590 | 590 | ||
| 591 | if args.embeddings_dir is not None: | ||
| 592 | embeddings_dir = Path(args.embeddings_dir) | ||
| 593 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
| 594 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
| 595 | |||
| 596 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
| 597 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
| 598 | |||
| 599 | if len(args.alias_tokens) != 0: | 591 | if len(args.alias_tokens) != 0: |
| 600 | alias_placeholder_tokens = args.alias_tokens[::2] | 592 | alias_placeholder_tokens = args.alias_tokens[::2] |
| 601 | alias_initializer_tokens = args.alias_tokens[1::2] | 593 | alias_initializer_tokens = args.alias_tokens[1::2] |
| @@ -609,6 +601,14 @@ def main(): | |||
| 609 | embeddings.persist() | 601 | embeddings.persist() |
| 610 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 602 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") |
| 611 | 603 | ||
| 604 | if args.embeddings_dir is not None: | ||
| 605 | embeddings_dir = Path(args.embeddings_dir) | ||
| 606 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
| 607 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
| 608 | |||
| 609 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
| 610 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
| 611 | |||
| 612 | if args.scale_lr: | 612 | if args.scale_lr: |
| 613 | args.learning_rate = ( | 613 | args.learning_rate = ( |
| 614 | args.learning_rate * args.gradient_accumulation_steps * | 614 | args.learning_rate * args.gradient_accumulation_steps * |
