Diffstat (limited to 'train_ti.py')
 train_ti.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 451b61b..c118aab 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,6 +15,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
+from training.lr import plot_metrics
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -61,6 +62,12 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--skip_first",
+        type=int,
+        default=0,
+        help="Tokens to skip training for.",
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -407,7 +414,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=10,
+        default=1e0,
         type=float,
         help="Embedding decay factor."
     )
@@ -543,6 +550,10 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    if args.find_lr:
+        args.learning_rate = 1e-5
+        args.lr_scheduler = "exponential_growth"
+
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -596,6 +607,9 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+        if i < args.skip_first:
+            return
+
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
         else:
@@ -656,11 +670,12 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            end_lr=1e3,
             train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
         )
 
-        trainer(
+        metrics = trainer(
             project="textual_inversion",
             train_dataloader=datamodule.train_dataloader,
             val_dataloader=datamodule.val_dataloader,
@@ -672,6 +687,8 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )
 
+        plot_metrics(metrics, output_dir.joinpath("lr.png"))
+
     if args.simultaneous:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
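
For orientation, this commit wires a learning-rate range test into train_ti.py: with --find_lr, the starting learning rate is forced to 1e-5, the scheduler is switched to exponential growth toward end_lr=1e3, and the per-step metrics returned by the trainer are plotted via plot_metrics from training/lr.py. The snippet below is only an illustrative sketch of that plotting step, not the repository's implementation: it assumes metrics is an iterable of (learning rate, loss) pairs and uses a hypothetical plot_lr_curve helper, so the real plot_metrics signature may differ.

# Illustrative sketch only; the repository's plot_metrics (training/lr.py) may differ.
from pathlib import Path
from typing import Iterable, Tuple

import matplotlib.pyplot as plt


def plot_lr_curve(metrics: Iterable[Tuple[float, float]], output_file: Path) -> None:
    """Plot loss against learning rate on a log-x axis, as in an LR range test."""
    lrs, losses = zip(*metrics)

    fig, ax = plt.subplots()
    ax.plot(lrs, losses)
    ax.set_xscale("log")  # the learning rate grows exponentially, so use a log scale
    ax.set_xlabel("learning rate")
    ax.set_ylabel("loss")
    fig.savefig(output_file, dpi=300)
    plt.close(fig)


if __name__ == "__main__":
    # Dummy data: loss falls at first, then diverges as the learning rate grows.
    demo = [(1e-5 * 10 ** (i / 10), 1.0 / (1 + i) + (i / 20) ** 4) for i in range(40)]
    plot_lr_curve(demo, Path("lr.png"))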
