path: root/train_ti.py
author     Volpeon <git@volpeon.ink>   2023-03-17 15:18:20 +0100
committer  Volpeon <git@volpeon.ink>   2023-03-17 15:18:20 +0100
commit     8abbd633d8ee7500058b2f1f69a6d6611b5a4450 (patch)
tree       f60d6e384966ba05354b30f08a32a38279b56165 /train_ti.py
parent     Update (diff)
download   textual-inversion-diff-8abbd633d8ee7500058b2f1f69a6d6611b5a4450.tar.gz
           textual-inversion-diff-8abbd633d8ee7500058b2f1f69a6d6611b5a4450.tar.bz2
           textual-inversion-diff-8abbd633d8ee7500058b2f1f69a6d6611b5a4450.zip

Test: https://arxiv.org/pdf/2303.09556.pdf
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 38
1 file changed, 26 insertions(+), 12 deletions(-)
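
The linked paper is the Min-SNR weighting strategy for diffusion training. The loss-weighting change it describes is not visible in this file's hunks, but a minimal sketch of the idea, assuming epsilon-prediction, a DDPM-style `alphas_cumprod` table, and a hypothetical `min_snr_gamma` hyperparameter, could look like:

```python
import torch

def min_snr_weights(timesteps, alphas_cumprod, gamma=5.0):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # Min-SNR-gamma weight for epsilon-prediction: min(SNR, gamma) / SNR
    return torch.minimum(snr, torch.full_like(snr, gamma)) / snr

# Hypothetical use inside a training step (names are illustrative, not from this repo):
# mse = F.mse_loss(model_pred.float(), noise.float(), reduction="none").mean(dim=[1, 2, 3])
# loss = (min_snr_weights(timesteps, noise_scheduler.alphas_cumprod, args.min_snr_gamma) * mse).mean()
```
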
diff --git a/train_ti.py b/train_ti.py
index 81938c8..fd23517 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -62,12 +62,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--skip_first",
-        type=int,
-        default=0,
-        help="Tokens to skip training for.",
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -80,6 +74,13 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
         "--num_vectors",
         type=int,
         nargs='*',
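
For reference, the new flag collects a flat list of alias/source pairs via `nargs='*'`. A small standalone sketch (the token names below are made up for illustration):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--alias_tokens", type=str, nargs='*', default=[],
                    help="Tokens to create an alias for.")

# e.g. --alias_tokens "<cls>" "person" "<style>" "painting"
args = parser.parse_args(["--alias_tokens", "<cls>", "person", "<style>", "painting"])
print(args.alias_tokens)  # ['<cls>', 'person', '<style>', 'painting']
assert len(args.alias_tokens) % 2 == 0  # the even-length check added later in parse_args()
```
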
@@ -420,7 +421,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=1e2,
         type=float,
         help="Embedding decay factor."
     )
@@ -482,6 +483,9 @@ def parse_args():
     if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
@@ -543,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
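
Both attention hooks are simply commented out here. If one wanted to keep them switchable rather than removed, a guarded sketch could look like the following; the function name is hypothetical, and it assumes `vae` and `unet` are the models loaded earlier and that xformers may not be installed:

```python
import importlib.util

def maybe_enable_xformers(vae, unet) -> bool:
    """Enable xformers memory-efficient attention only if the package is installed."""
    if importlib.util.find_spec("xformers") is None:
        return False  # fall back to the default (e.g. PyTorch 2.0 SDP) attention
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.enable_xformers_memory_efficient_attention()
    return True
```
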
@@ -559,6 +563,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
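
The strided slicing above splits the flat `--alias_tokens` list into (placeholder, initializer) pairs; a tiny illustration with made-up token names:

```python
alias_tokens = ["<cls>", "person", "<style>", "painting"]  # hypothetical values

alias_placeholder_tokens = alias_tokens[::2]    # ['<cls>', '<style>']
alias_initializer_tokens = alias_tokens[1::2]   # ['person', 'painting']

print(list(zip(alias_placeholder_tokens, alias_initializer_tokens)))
# [('<cls>', 'person'), ('<style>', 'painting')]
```
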
@@ -633,9 +650,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
-        if i < args.skip_first:
-            return
-
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
             metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"