diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 20 |
1 files changed, 1 insertions, 19 deletions
diff --git a/train_ti.py b/train_ti.py index 1dbd637..8c63493 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -112,12 +112,6 @@ def parse_args(): | |||
112 | help="Tokens to create an alias for.", | 112 | help="Tokens to create an alias for.", |
113 | ) | 113 | ) |
114 | parser.add_argument( | 114 | parser.add_argument( |
115 | "--inverted_initializer_tokens", | ||
116 | type=str, | ||
117 | nargs="*", | ||
118 | help="A token to use as initializer word.", | ||
119 | ) | ||
120 | parser.add_argument( | ||
121 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | 115 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
122 | ) | 116 | ) |
123 | parser.add_argument( | 117 | parser.add_argument( |
@@ -545,18 +539,6 @@ def parse_args(): | |||
545 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | 539 | "--placeholder_tokens and --initializer_tokens must have the same number of items" |
546 | ) | 540 | ) |
547 | 541 | ||
548 | if isinstance(args.inverted_initializer_tokens, str): | ||
549 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
550 | args.placeholder_tokens | ||
551 | ) | ||
552 | |||
553 | if ( | ||
554 | isinstance(args.inverted_initializer_tokens, list) | ||
555 | and len(args.inverted_initializer_tokens) != 0 | ||
556 | ): | ||
557 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
558 | args.initializer_tokens += args.inverted_initializer_tokens | ||
559 | |||
560 | if isinstance(args.num_vectors, int): | 542 | if isinstance(args.num_vectors, int): |
561 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 543 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
562 | 544 | ||
@@ -872,7 +854,7 @@ def main(): | |||
872 | 854 | ||
873 | optimizer = create_optimizer( | 855 | optimizer = create_optimizer( |
874 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 856 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
875 | lr=learning_rate, | 857 | lr=args.learning_rate, |
876 | ) | 858 | ) |
877 | 859 | ||
878 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 860 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |