From 950f1f6bcbb1a767170cea590b828d8e3cdae882 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 23 Jun 2023 06:48:38 +0200
Subject: Update

---
 train_ti.py | 20 +-------------------
 1 file changed, 1 insertion(+), 19 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 1dbd637..8c63493 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -111,12 +111,6 @@ def parse_args():
         default=[],
         help="Tokens to create an alias for.",
     )
-    parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
     parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
@@ -545,18 +539,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
@@ -872,7 +854,7 @@ def main():
 
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=learning_rate,
+        lr=args.learning_rate,
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
--
cgit v1.2.3-54-g00ecf