| author | Volpeon <git@volpeon.ink> | 2023-03-17 15:18:20 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-17 15:18:20 +0100 |
| commit | 8abbd633d8ee7500058b2f1f69a6d6611b5a4450 | |
| tree | f60d6e384966ba05354b30f08a32a38279b56165 /train_ti.py | |
| parent | Update | |
Test: https://arxiv.org/pdf/2303.09556.pdf
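The linked paper is "Efficient Diffusion Training via Min-SNR Weighting Strategy" (Hang et al., arXiv 2303.09556), which rebalances the diffusion loss across timesteps by capping each step's signal-to-noise ratio at a constant γ. As a rough orientation for what is being tested, here is a minimal sketch of that weighting for ε-prediction, assuming a DDPM-style schedule that exposes `alphas_cumprod`; the γ value and all surrounding names are illustrative, not code from this commit:

```python
import torch

def min_snr_weights(alphas_cumprod: torch.Tensor, timesteps: torch.Tensor,
                    gamma: float = 5.0) -> torch.Tensor:
    """Per-sample loss weights min(SNR_t, gamma) / SNR_t (eps-prediction form)."""
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)  # SNR(t) of the forward noising process
    # Low-noise (high-SNR) timesteps get down-weighted; noisy ones keep weight ~1.
    return snr.clamp(max=gamma) / snr

# Illustrative use in a training step:
# loss = F.mse_loss(model_pred.float(), noise.float(), reduction="none").mean(dim=[1, 2, 3])
# loss = (min_snr_weights(noise_scheduler.alphas_cumprod, timesteps) * loss).mean()
```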
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 38 |
1 file changed, 26 insertions(+), 12 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 81938c8..fd23517 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -62,12 +62,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--skip_first",
-        type=int,
-        default=0,
-        help="Tokens to skip training for.",
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -80,6 +74,13 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
         "--num_vectors",
         type=int,
         nargs='*',
@@ -420,7 +421,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=1e2,
         type=float,
         help="Embedding decay factor."
     )
```
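The `--emb_decay` default jumps from `1e-2` to `1e2`. One plausible reading, sketched below under assumptions rather than taken from this repo's code, is that the factor is scaled by the learning rate before use, so only the product `lr * emb_decay` matters and a larger default stays gentle:

```python
import torch

@torch.no_grad()
def decay_embedding_norms(emb: torch.Tensor, lr: float, emb_decay: float,
                          target: float = 0.4) -> None:
    """Multiplicatively pull each embedding vector's norm toward `target`.

    The update rule and `target` value are assumptions for illustration only.
    """
    norm = emb.norm(dim=-1, keepdim=True)
    # With lr ~= 1e-4, emb_decay = 1e2 yields a small exponent of about 1e-2.
    emb.mul_((target / norm) ** min(1.0, lr * emb_decay))
```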
```diff
@@ -482,6 +483,9 @@ def parse_args():
     if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
```
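The check above forces `--alias_tokens` to arrive as flat (placeholder, source) pairs; the even/odd slicing added to `main()` further down then splits them apart. A small demonstration with made-up token names:

```python
# Flat pair list as argparse would deliver it (token names are invented).
alias_tokens = ["<night>", "night", "<forest>", "forest"]
assert len(alias_tokens) % 2 == 0  # the new parse_args() validation

alias_placeholder_tokens = alias_tokens[::2]   # ["<night>", "<forest>"]
alias_initializer_tokens = alias_tokens[1::2]  # ["night", "forest"]
```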
```diff
@@ -543,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
@@ -559,6 +563,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
```
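Conceptually, each alias becomes a new tokenizer entry whose embedding is copied from an existing token and then persisted, so it can be used in prompts without itself being trained. `add_placeholder_tokens` and `embeddings.persist()` are this repo's helpers; the sketch below illustrates the same idea with a stock CLIP tokenizer and text encoder, and is not their implementation:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

def add_alias(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel,
              alias: str, source: str) -> int:
    """Register `alias` as a new token initialized from `source`'s embedding."""
    tokenizer.add_tokens(alias)
    alias_id = tokenizer.convert_tokens_to_ids(alias)
    source_id = tokenizer.convert_tokens_to_ids(source)

    text_encoder.resize_token_embeddings(len(tokenizer))
    emb = text_encoder.get_input_embeddings().weight
    with torch.no_grad():
        emb[alias_id] = emb[source_id]  # plain copy; the alias is not trained further
    return alias_id
```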
```diff
@@ -633,9 +650,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
-        if i < args.skip_first:
-            return
-
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
             metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
```
