From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:44:43 +0100 Subject: Add Lora --- train_ti.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index c118aab..56f9e97 100644 --- a/train_ti.py +++ b/train_ti.py @@ -166,7 +166,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0, + default=0.1, help="Tag dropout probability.", ) parser.add_argument( @@ -414,7 +414,7 @@ def parse_args(): ) parser.add_argument( "--emb_decay", - default=1e0, + default=1e-2, type=float, help="Embedding decay factor." ) @@ -530,7 +530,7 @@ def main(): vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) - unet.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -612,8 +612,10 @@ def main(): if len(placeholder_tokens) == 1: sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") + metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") else: sample_output_dir = output_dir.joinpath("samples") + metrics_output_file = output_dir.joinpath(f"lr.png") placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, @@ -687,7 +689,7 @@ def main(): placeholder_token_ids=placeholder_token_ids, ) - plot_metrics(metrics, output_dir.joinpath("lr.png")) + plot_metrics(metrics, metrics_output_file) if args.simultaneous: run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) -- cgit v1.2.3-54-g00ecf