diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 20 |
1 files changed, 1 insertions, 19 deletions
diff --git a/train_ti.py b/train_ti.py index 1dbd637..8c63493 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -112,12 +112,6 @@ def parse_args(): | |||
| 112 | help="Tokens to create an alias for.", | 112 | help="Tokens to create an alias for.", |
| 113 | ) | 113 | ) |
| 114 | parser.add_argument( | 114 | parser.add_argument( |
| 115 | "--inverted_initializer_tokens", | ||
| 116 | type=str, | ||
| 117 | nargs="*", | ||
| 118 | help="A token to use as initializer word.", | ||
| 119 | ) | ||
| 120 | parser.add_argument( | ||
| 121 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | 115 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
| 122 | ) | 116 | ) |
| 123 | parser.add_argument( | 117 | parser.add_argument( |
| @@ -545,18 +539,6 @@ def parse_args(): | |||
| 545 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | 539 | "--placeholder_tokens and --initializer_tokens must have the same number of items" |
| 546 | ) | 540 | ) |
| 547 | 541 | ||
| 548 | if isinstance(args.inverted_initializer_tokens, str): | ||
| 549 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
| 550 | args.placeholder_tokens | ||
| 551 | ) | ||
| 552 | |||
| 553 | if ( | ||
| 554 | isinstance(args.inverted_initializer_tokens, list) | ||
| 555 | and len(args.inverted_initializer_tokens) != 0 | ||
| 556 | ): | ||
| 557 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
| 558 | args.initializer_tokens += args.inverted_initializer_tokens | ||
| 559 | |||
| 560 | if isinstance(args.num_vectors, int): | 542 | if isinstance(args.num_vectors, int): |
| 561 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 543 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 562 | 544 | ||
| @@ -872,7 +854,7 @@ def main(): | |||
| 872 | 854 | ||
| 873 | optimizer = create_optimizer( | 855 | optimizer = create_optimizer( |
| 874 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 856 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 875 | lr=learning_rate, | 857 | lr=args.learning_rate, |
| 876 | ) | 858 | ) |
| 877 | 859 | ||
| 878 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 860 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
