diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index c118aab..56f9e97 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -166,7 +166,7 @@ def parse_args(): | |||
| 166 | parser.add_argument( | 166 | parser.add_argument( |
| 167 | "--tag_dropout", | 167 | "--tag_dropout", |
| 168 | type=float, | 168 | type=float, |
| 169 | default=0, | 169 | default=0.1, |
| 170 | help="Tag dropout probability.", | 170 | help="Tag dropout probability.", |
| 171 | ) | 171 | ) |
| 172 | parser.add_argument( | 172 | parser.add_argument( |
| @@ -414,7 +414,7 @@ def parse_args(): | |||
| 414 | ) | 414 | ) |
| 415 | parser.add_argument( | 415 | parser.add_argument( |
| 416 | "--emb_decay", | 416 | "--emb_decay", |
| 417 | default=1e0, | 417 | default=1e-2, |
| 418 | type=float, | 418 | type=float, |
| 419 | help="Embedding decay factor." | 419 | help="Embedding decay factor." |
| 420 | ) | 420 | ) |
| @@ -530,7 +530,7 @@ def main(): | |||
| 530 | 530 | ||
| 531 | vae.enable_slicing() | 531 | vae.enable_slicing() |
| 532 | vae.set_use_memory_efficient_attention_xformers(True) | 532 | vae.set_use_memory_efficient_attention_xformers(True) |
| 533 | unet.set_use_memory_efficient_attention_xformers(True) | 533 | unet.enable_xformers_memory_efficient_attention() |
| 534 | 534 | ||
| 535 | if args.gradient_checkpointing: | 535 | if args.gradient_checkpointing: |
| 536 | unet.enable_gradient_checkpointing() | 536 | unet.enable_gradient_checkpointing() |
| @@ -612,8 +612,10 @@ def main(): | |||
| 612 | 612 | ||
| 613 | if len(placeholder_tokens) == 1: | 613 | if len(placeholder_tokens) == 1: |
| 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") | 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") |
| 615 | metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") | ||
| 615 | else: | 616 | else: |
| 616 | sample_output_dir = output_dir.joinpath("samples") | 617 | sample_output_dir = output_dir.joinpath("samples") |
| 618 | metrics_output_file = output_dir.joinpath(f"lr.png") | ||
| 617 | 619 | ||
| 618 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 620 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 619 | tokenizer=tokenizer, | 621 | tokenizer=tokenizer, |
| @@ -687,7 +689,7 @@ def main(): | |||
| 687 | placeholder_token_ids=placeholder_token_ids, | 689 | placeholder_token_ids=placeholder_token_ids, |
| 688 | ) | 690 | ) |
| 689 | 691 | ||
| 690 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 692 | plot_metrics(metrics, metrics_output_file) |
| 691 | 693 | ||
| 692 | if args.simultaneous: | 694 | if args.simultaneous: |
| 693 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 695 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
