summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  10
1 file changed, 6 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py
index c118aab..56f9e97 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -166,7 +166,7 @@ def parse_args():
166 parser.add_argument( 166 parser.add_argument(
167 "--tag_dropout", 167 "--tag_dropout",
168 type=float, 168 type=float,
169 default=0, 169 default=0.1,
170 help="Tag dropout probability.", 170 help="Tag dropout probability.",
171 ) 171 )
172 parser.add_argument( 172 parser.add_argument(
@@ -414,7 +414,7 @@ def parse_args():
414 ) 414 )
415 parser.add_argument( 415 parser.add_argument(
416 "--emb_decay", 416 "--emb_decay",
417 default=1e0, 417 default=1e-2,
418 type=float, 418 type=float,
419 help="Embedding decay factor." 419 help="Embedding decay factor."
420 ) 420 )
@@ -530,7 +530,7 @@ def main():
530 530
531 vae.enable_slicing() 531 vae.enable_slicing()
532 vae.set_use_memory_efficient_attention_xformers(True) 532 vae.set_use_memory_efficient_attention_xformers(True)
533 unet.set_use_memory_efficient_attention_xformers(True) 533 unet.enable_xformers_memory_efficient_attention()
534 534
535 if args.gradient_checkpointing: 535 if args.gradient_checkpointing:
536 unet.enable_gradient_checkpointing() 536 unet.enable_gradient_checkpointing()
@@ -612,8 +612,10 @@ def main():
612 612
613 if len(placeholder_tokens) == 1: 613 if len(placeholder_tokens) == 1:
614 sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") 614 sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
615 metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png")
615 else: 616 else:
616 sample_output_dir = output_dir.joinpath("samples") 617 sample_output_dir = output_dir.joinpath("samples")
618 metrics_output_file = output_dir.joinpath(f"lr.png")
617 619
618 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 620 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
619 tokenizer=tokenizer, 621 tokenizer=tokenizer,
@@ -687,7 +689,7 @@ def main():
687 placeholder_token_ids=placeholder_token_ids, 689 placeholder_token_ids=placeholder_token_ids,
688 ) 690 )
689 691
690 plot_metrics(metrics, output_dir.joinpath("lr.png")) 692 plot_metrics(metrics, metrics_output_file)
691 693
692 if args.simultaneous: 694 if args.simultaneous:
693 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) 695 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)