Diffstat (limited to 'train_ti.py')
 train_ti.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 451b61b..c118aab 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,6 +15,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
+from training.lr import plot_metrics
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -61,6 +62,12 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--skip_first",
+        type=int,
+        default=0,
+        help="Tokens to skip training for.",
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -407,7 +414,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=10,
+        default=1e0,
         type=float,
         help="Embedding decay factor."
     )
@@ -543,6 +550,10 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    if args.find_lr:
+        args.learning_rate = 1e-5
+        args.lr_scheduler = "exponential_growth"
+
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -596,6 +607,9 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+        if i < args.skip_first:
+            return
+
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
         else:
@@ -656,11 +670,12 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            end_lr=1e3,
             train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
         )
 
-        trainer(
+        metrics = trainer(
             project="textual_inversion",
             train_dataloader=datamodule.train_dataloader,
             val_dataloader=datamodule.val_dataloader,
@@ -672,6 +687,8 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )
 
+        plot_metrics(metrics, output_dir.joinpath("lr.png"))
+
     if args.simultaneous:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
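
Taken together, these changes wire a learning-rate range test into the textual-inversion loop: with --find_lr set, training starts at 1e-5, the "exponential_growth" scheduler ramps the rate up toward end_lr=1e3, and the metrics returned by trainer(...) are written to lr.png via the newly imported plot_metrics. The sketch below illustrates that pattern; find_lr_sweep and plot_lr_curve are hypothetical stand-ins, since training/optimization.py and training/lr.py are not part of this diff.

# Minimal sketch of an exponential-growth LR range test, assuming a
# PyTorch-style optimizer and a train_step() that returns a float loss.
# find_lr_sweep and plot_lr_curve are illustrative stand-ins for the
# real helpers in training/optimization.py and training/lr.py.
import math
import matplotlib.pyplot as plt

def find_lr_sweep(optimizer, train_step, steps=100, start_lr=1e-5, end_lr=1e3):
    # Multiply the LR by a constant factor each step so it grows
    # exponentially from start_lr to end_lr over `steps` steps.
    growth = (end_lr / start_lr) ** (1.0 / max(steps - 1, 1))
    lrs, losses, lr = [], [], start_lr
    for _ in range(steps):
        for group in optimizer.param_groups:
            group["lr"] = lr
        loss = train_step()  # one forward/backward/update pass
        lrs.append(lr)
        losses.append(loss)
        if math.isnan(loss) or loss > 4 * min(losses):
            break  # loss diverged; stop the sweep early
        lr *= growth
    return lrs, losses

def plot_lr_curve(lrs, losses, path="lr.png"):
    # Loss vs. learning rate on a log axis, analogous to what
    # plot_metrics presumably renders.
    fig, ax = plt.subplots()
    ax.plot(lrs, losses)
    ax.set_xscale("log")
    ax.set_xlabel("learning rate")
    ax.set_ylabel("loss")
    fig.savefig(path)

A usable learning rate is typically read off such a curve roughly an order of magnitude below the point where the loss starts to climb.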