From 8abbd633d8ee7500058b2f1f69a6d6611b5a4450 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 17 Mar 2023 15:18:20 +0100
Subject: Test: https://arxiv.org/pdf/2303.09556.pdf

---
 train_ti.py | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 81938c8..fd23517 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -61,12 +61,6 @@ def parse_args():
         default=None,
         help="The name of the current project.",
     )
-    parser.add_argument(
-        "--skip_first",
-        type=int,
-        default=0,
-        help="Tokens to skip training for.",
-    )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
@@ -79,6 +73,13 @@ def parse_args():
         nargs='*',
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
     parser.add_argument(
         "--num_vectors",
         type=int,
@@ -420,7 +421,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=1e2,
         type=float,
         help="Embedding decay factor."
    )
@@ -482,6 +483,9 @@ def parse_args():
     if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
@@ -543,8 +547,8 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
@@ -559,6 +563,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -633,9 +650,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
-        if i < args.skip_first:
-            return
-
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
             metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
--
cgit v1.2.3-54-g00ecf
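
The new --alias_tokens flag is read in pairs: even positions name the new placeholder tokens, odd positions name the existing tokens whose embeddings initialize them, which is why parse_args() rejects lists of odd length. A minimal sketch of the pairing, mirroring the hunk above (the token names are hypothetical):

    # pairs are passed flat on the command line, e.g.
    #   --alias_tokens "<fox>" "fox" "<deer>" "deer"   (hypothetical tokens)
    alias_tokens = ["<fox>", "fox", "<deer>", "deer"]
    assert len(alias_tokens) % 2 == 0               # enforced in parse_args()
    alias_placeholder_tokens = alias_tokens[::2]    # ["<fox>", "<deer>"]
    alias_initializer_tokens = alias_tokens[1::2]   # ["fox", "deer"]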
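
The subject line points at "Efficient Diffusion Training via Min-SNR Weighting Strategy" (arXiv:2303.09556). The weighting itself is not visible in this diff, which is limited to train_ti.py; purely as a reference, a sketch of the technique for epsilon-prediction, assuming a diffusers-style scheduler that exposes alphas_cumprod (every name below is an assumption, not code from this commit):

    import torch

    def min_snr_weight(alphas_cumprod: torch.Tensor,
                       timesteps: torch.Tensor,
                       gamma: float = 5.0) -> torch.Tensor:
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); clamping at gamma and
        # dividing by SNR down-weights the low-noise timesteps that
        # otherwise dominate a plain MSE objective
        alpha_bar = alphas_cumprod[timesteps]
        snr = alpha_bar / (1.0 - alpha_bar)
        return torch.minimum(snr, torch.full_like(snr, gamma)) / snr

    # per-sample weighting of the usual noise-prediction MSE (sketch):
    # loss = F.mse_loss(model_pred, noise, reduction="none").mean(dim=(1, 2, 3))
    # loss = (min_snr_weight(scheduler.alphas_cumprod, timesteps) * loss).mean()

gamma = 5 is the value the paper recommends.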