From 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Apr 2023 17:38:49 +0200 Subject: Update --- train_lora.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) (limited to 'train_lora.py') diff --git a/train_lora.py b/train_lora.py index 1626be6..e4b5546 100644 --- a/train_lora.py +++ b/train_lora.py @@ -92,6 +92,12 @@ def parse_args(): nargs='*', help="A token to use as initializer word." ) + parser.add_argument( + "--filter_tokens", + type=str, + nargs='*', + help="Tokens to filter the dataset by." + ) parser.add_argument( "--initializer_noise", type=float, @@ -592,6 +598,12 @@ def parse_args(): if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: raise ValueError("--alias_tokens must be a list with an even number of items") + if args.filter_tokens is None: + args.filter_tokens = args.placeholder_tokens.copy() + + if isinstance(args.filter_tokens, str): + args.filter_tokens = [args.filter_tokens] + if isinstance(args.collection, str): args.collection = [args.collection] @@ -890,7 +902,7 @@ def main(): pti_datamodule = create_datamodule( batch_size=args.pti_batch_size, - filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), + filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections), ) pti_datamodule.setup() @@ -906,7 +918,7 @@ def main(): pti_optimizer = create_optimizer( [ { - "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), + "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), "lr": args.learning_rate_pti, "weight_decay": 0, }, @@ -937,7 +949,7 @@ def main(): sample_frequency=pti_sample_frequency, ) - # embeddings.persist() + embeddings.persist() # LORA # -------------------------------------------------------------------------------- @@ -962,13 +974,13 @@ def main(): params_to_optimize = [] group_labels = [] - if len(args.placeholder_tokens) != 0: - params_to_optimize.append({ - "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), - "lr": args.learning_rate_text, - "weight_decay": 0, - }) - group_labels.append("emb") + # if len(args.placeholder_tokens) != 0: + # params_to_optimize.append({ + # "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), + # "lr": args.learning_rate_text, + # "weight_decay": 0, + # }) + # group_labels.append("emb") params_to_optimize += [ { "params": ( -- cgit v1.2.3-54-g00ecf