diff options
author | Volpeon <git@volpeon.ink> | 2023-01-20 14:26:17 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-20 14:26:17 +0100 |
commit | 3575d041f1507811b577fd2c653171fb51c0a386 (patch) | |
tree | 702f9f1ae4eafc6f8ea06560c4de6bbe1c2acecb /train_ti.py | |
parent | Move Accelerator preparation into strategy (diff) | |
download | textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.tar.gz textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.tar.bz2 textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.zip |
Restored LR finder
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 21 |
1 file changed, 19 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index 451b61b..c118aab 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -15,6 +15,7 @@ from slugify import slugify | |||
15 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
16 | from data.csv import VlpnDataModule, keyword_filter | 16 | from data.csv import VlpnDataModule, keyword_filter |
17 | from training.functional import train, add_placeholder_tokens, get_models | 17 | from training.functional import train, add_placeholder_tokens, get_models |
18 | from training.lr import plot_metrics | ||
18 | from training.strategy.ti import textual_inversion_strategy | 19 | from training.strategy.ti import textual_inversion_strategy |
19 | from training.optimization import get_scheduler | 20 | from training.optimization import get_scheduler |
20 | from training.util import save_args | 21 | from training.util import save_args |
@@ -61,6 +62,12 @@ def parse_args(): | |||
61 | help="The name of the current project.", | 62 | help="The name of the current project.", |
62 | ) | 63 | ) |
63 | parser.add_argument( | 64 | parser.add_argument( |
65 | "--skip_first", | ||
66 | type=int, | ||
67 | default=0, | ||
68 | help="Tokens to skip training for.", | ||
69 | ) | ||
70 | parser.add_argument( | ||
64 | "--placeholder_tokens", | 71 | "--placeholder_tokens", |
65 | type=str, | 72 | type=str, |
66 | nargs='*', | 73 | nargs='*', |
@@ -407,7 +414,7 @@ def parse_args(): | |||
407 | ) | 414 | ) |
408 | parser.add_argument( | 415 | parser.add_argument( |
409 | "--emb_decay", | 416 | "--emb_decay", |
410 | default=10, | 417 | default=1e0, |
411 | type=float, | 418 | type=float, |
412 | help="Embedding decay factor." | 419 | help="Embedding decay factor." |
413 | ) | 420 | ) |
@@ -543,6 +550,10 @@ def main(): | |||
543 | args.train_batch_size * accelerator.num_processes | 550 | args.train_batch_size * accelerator.num_processes |
544 | ) | 551 | ) |
545 | 552 | ||
553 | if args.find_lr: | ||
554 | args.learning_rate = 1e-5 | ||
555 | args.lr_scheduler = "exponential_growth" | ||
556 | |||
546 | if args.use_8bit_adam: | 557 | if args.use_8bit_adam: |
547 | try: | 558 | try: |
548 | import bitsandbytes as bnb | 559 | import bitsandbytes as bnb |
@@ -596,6 +607,9 @@ def main(): | |||
596 | ) | 607 | ) |
597 | 608 | ||
598 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 609 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): |
610 | if i < args.skip_first: | ||
611 | return | ||
612 | |||
599 | if len(placeholder_tokens) == 1: | 613 | if len(placeholder_tokens) == 1: |
600 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") | 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") |
601 | else: | 615 | else: |
@@ -656,11 +670,12 @@ def main(): | |||
656 | warmup_exp=args.lr_warmup_exp, | 670 | warmup_exp=args.lr_warmup_exp, |
657 | annealing_exp=args.lr_annealing_exp, | 671 | annealing_exp=args.lr_annealing_exp, |
658 | cycles=args.lr_cycles, | 672 | cycles=args.lr_cycles, |
673 | end_lr=1e3, | ||
659 | train_epochs=args.num_train_epochs, | 674 | train_epochs=args.num_train_epochs, |
660 | warmup_epochs=args.lr_warmup_epochs, | 675 | warmup_epochs=args.lr_warmup_epochs, |
661 | ) | 676 | ) |
662 | 677 | ||
663 | trainer( | 678 | metrics = trainer( |
664 | project="textual_inversion", | 679 | project="textual_inversion", |
665 | train_dataloader=datamodule.train_dataloader, | 680 | train_dataloader=datamodule.train_dataloader, |
666 | val_dataloader=datamodule.val_dataloader, | 681 | val_dataloader=datamodule.val_dataloader, |
@@ -672,6 +687,8 @@ def main(): | |||
672 | placeholder_token_ids=placeholder_token_ids, | 687 | placeholder_token_ids=placeholder_token_ids, |
673 | ) | 688 | ) |
674 | 689 | ||
690 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | ||
691 | |||
675 | if args.simultaneous: | 692 | if args.simultaneous: |
676 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 693 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
677 | else: | 694 | else: |