From 19ae465203c8dcc0b1179584db632015362b5e44 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 26 Mar 2023 14:27:54 +0200 Subject: Improved inverted tokens --- train_ti.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 83ad46d..6c35d41 100644 --- a/train_ti.py +++ b/train_ti.py @@ -80,6 +80,12 @@ def parse_args(): default=[], help="Tokens to create an alias for." ) + parser.add_argument( + "--inverted_initializer_tokens", + type=str, + nargs='*', + help="A token to use as initializer word." + ) parser.add_argument( "--num_vectors", type=int, @@ -149,7 +155,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=0, + default=2, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -488,6 +494,13 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.initializer_tokens): raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") + if isinstance(args.inverted_initializer_tokens, str): + args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) + + if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: + args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] + args.initializer_tokens += args.inverted_initializer_tokens + if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) @@ -720,6 +733,7 @@ def main(): dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=data_template, + placeholder_tokens=args.placeholder_tokens, valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, -- cgit v1.2.3-54-g00ecf