summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py20
1 files changed, 1 insertions, 19 deletions
diff --git a/train_ti.py b/train_ti.py
index 1dbd637..8c63493 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -112,12 +112,6 @@ def parse_args():
112 help="Tokens to create an alias for.", 112 help="Tokens to create an alias for.",
113 ) 113 )
114 parser.add_argument( 114 parser.add_argument(
115 "--inverted_initializer_tokens",
116 type=str,
117 nargs="*",
118 help="A token to use as initializer word.",
119 )
120 parser.add_argument(
121 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." 115 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
122 ) 116 )
123 parser.add_argument( 117 parser.add_argument(
@@ -545,18 +539,6 @@ def parse_args():
545 "--placeholder_tokens and --initializer_tokens must have the same number of items" 539 "--placeholder_tokens and --initializer_tokens must have the same number of items"
546 ) 540 )
547 541
548 if isinstance(args.inverted_initializer_tokens, str):
549 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
550 args.placeholder_tokens
551 )
552
553 if (
554 isinstance(args.inverted_initializer_tokens, list)
555 and len(args.inverted_initializer_tokens) != 0
556 ):
557 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
558 args.initializer_tokens += args.inverted_initializer_tokens
559
560 if isinstance(args.num_vectors, int): 542 if isinstance(args.num_vectors, int):
561 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 543 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
562 544
@@ -872,7 +854,7 @@ def main():
872 854
873 optimizer = create_optimizer( 855 optimizer = create_optimizer(
874 text_encoder.text_model.embeddings.token_embedding.parameters(), 856 text_encoder.text_model.embeddings.token_embedding.parameters(),
875 lr=learning_rate, 857 lr=args.learning_rate,
876 ) 858 )
877 859
878 data_generator = torch.Generator(device="cpu").manual_seed(args.seed) 860 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)