From 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 17:38:49 +0200
Subject: Update

---
 train_ti.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 48858cc..daf8bc5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+from typing import Union

 import math
 import torch
@@ -74,6 +75,12 @@ def parse_args():
         nargs='*',
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
     parser.add_argument(
         "--initializer_noise",
         type=float,
@@ -538,6 +545,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")

+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if args.sequential:
         args.alias_tokens += [
             item
@@ -779,13 +792,11 @@ def main():
         sample_image_size=args.sample_image_size,
     )

-    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
             sample_output_dir = output_dir / "samples"
-            metrics_output_file = output_dir / "lr.png"

         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -800,6 +811,8 @@ def main():

         print(f"{i + 1}: {stats}")

+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
             batch_size=args.train_batch_size,
@@ -820,7 +833,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -834,7 +847,7 @@ def main():

         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))

         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             lr=args.learning_rate,
         )
--
cgit v1.2.3-54-g00ecf
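
The core of this commit is the new --filter_tokens option: parse_args() defaults it to the placeholder tokens and coerces a bare string into a list, and run() then intersects it with the tokens trained in that run before handing it to keyword_filter via functools.partial. The following standalone sketch illustrates that flow; it is not code from the patch, and keyword_filter here is a hypothetical stand-in (the real one lives elsewhere in the repo), assuming only the argument order implied by the diff (tokens, collection, exclude_collections, then the data item).

    # Sketch of the --filter_tokens flow introduced by this commit.
    # `keyword_filter` below is a hypothetical stand-in, assuming the
    # argument order implied by the partial(...) call in the diff.
    from functools import partial


    def normalize_filter_tokens(filter_tokens, placeholder_tokens):
        """Mirror the parse_args() logic: default to the placeholder
        tokens, and coerce a single string into a one-element list."""
        if filter_tokens is None:
            filter_tokens = placeholder_tokens.copy()
        if isinstance(filter_tokens, str):
            filter_tokens = [filter_tokens]
        return filter_tokens


    def keyword_filter(tokens, collection, exclude_collections, item):
        # Stand-in predicate: keep an item if its prompt mentions any token.
        return any(token in item["prompt"] for token in tokens)


    placeholder_tokens = ["<vol>", "<peon>"]
    filter_tokens = normalize_filter_tokens(None, placeholder_tokens)

    # As in run(): restrict the filter to tokens trained in this run.
    filter_tokens = [t for t in filter_tokens if t in placeholder_tokens]

    # The datamodule receives a single-item predicate, as in the diff.
    item_filter = partial(keyword_filter, filter_tokens, None, None)
    print(item_filter({"prompt": "a photo of <vol>"}))   # True
    print(item_filter({"prompt": "a photo of a cat"}))   # False

The practical effect is that the dataset filter no longer has to match the placeholder tokens exactly: passing --filter_tokens decouples which items are selected for training from which tokens are being learned, while the old behavior remains the default.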