diff options
author | Volpeon <git@volpeon.ink> | 2023-02-07 20:44:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-07 20:44:43 +0100 |
commit | 7ccd4614a56cfd6ecacba85605f338593f1059f0 (patch) | |
tree | fa9882b256c752705bc42229bac4e00ed7088643 /train_ti.py | |
parent | Restored LR finder (diff) | |
download | textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.tar.gz textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.tar.bz2 textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.zip |
Add Lora
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index c118aab..56f9e97 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -166,7 +166,7 @@ def parse_args(): | |||
166 | parser.add_argument( | 166 | parser.add_argument( |
167 | "--tag_dropout", | 167 | "--tag_dropout", |
168 | type=float, | 168 | type=float, |
169 | default=0, | 169 | default=0.1, |
170 | help="Tag dropout probability.", | 170 | help="Tag dropout probability.", |
171 | ) | 171 | ) |
172 | parser.add_argument( | 172 | parser.add_argument( |
@@ -414,7 +414,7 @@ def parse_args(): | |||
414 | ) | 414 | ) |
415 | parser.add_argument( | 415 | parser.add_argument( |
416 | "--emb_decay", | 416 | "--emb_decay", |
417 | default=1e0, | 417 | default=1e-2, |
418 | type=float, | 418 | type=float, |
419 | help="Embedding decay factor." | 419 | help="Embedding decay factor." |
420 | ) | 420 | ) |
@@ -530,7 +530,7 @@ def main(): | |||
530 | 530 | ||
531 | vae.enable_slicing() | 531 | vae.enable_slicing() |
532 | vae.set_use_memory_efficient_attention_xformers(True) | 532 | vae.set_use_memory_efficient_attention_xformers(True) |
533 | unet.set_use_memory_efficient_attention_xformers(True) | 533 | unet.enable_xformers_memory_efficient_attention() |
534 | 534 | ||
535 | if args.gradient_checkpointing: | 535 | if args.gradient_checkpointing: |
536 | unet.enable_gradient_checkpointing() | 536 | unet.enable_gradient_checkpointing() |
@@ -612,8 +612,10 @@ def main(): | |||
612 | 612 | ||
613 | if len(placeholder_tokens) == 1: | 613 | if len(placeholder_tokens) == 1: |
614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") | 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") |
615 | metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") | ||
615 | else: | 616 | else: |
616 | sample_output_dir = output_dir.joinpath("samples") | 617 | sample_output_dir = output_dir.joinpath("samples") |
618 | metrics_output_file = output_dir.joinpath(f"lr.png") | ||
617 | 619 | ||
618 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 620 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
619 | tokenizer=tokenizer, | 621 | tokenizer=tokenizer, |
@@ -687,7 +689,7 @@ def main(): | |||
687 | placeholder_token_ids=placeholder_token_ids, | 689 | placeholder_token_ids=placeholder_token_ids, |
688 | ) | 690 | ) |
689 | 691 | ||
690 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 692 | plot_metrics(metrics, metrics_output_file) |
691 | 693 | ||
692 | if args.simultaneous: | 694 | if args.simultaneous: |
693 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 695 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |