From 3575d041f1507811b577fd2c653171fb51c0a386 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 20 Jan 2023 14:26:17 +0100 Subject: Restored LR finder --- train_ti.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 451b61b..c118aab 100644 --- a/train_ti.py +++ b/train_ti.py @@ -15,6 +15,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter from training.functional import train, add_placeholder_tokens, get_models +from training.lr import plot_metrics from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler from training.util import save_args @@ -60,6 +61,12 @@ def parse_args(): default=None, help="The name of the current project.", ) + parser.add_argument( + "--skip_first", + type=int, + default=0, + help="Tokens to skip training for.", + ) parser.add_argument( "--placeholder_tokens", type=str, @@ -407,7 +414,7 @@ def parse_args(): ) parser.add_argument( "--emb_decay", - default=10, + default=1e0, type=float, help="Embedding decay factor." ) @@ -543,6 +550,10 @@ def main(): args.train_batch_size * accelerator.num_processes ) + if args.find_lr: + args.learning_rate = 1e-5 + args.lr_scheduler = "exponential_growth" + if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -596,6 +607,9 @@ def main(): ) def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): + if i < args.skip_first: + return + if len(placeholder_tokens) == 1: sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") else: @@ -656,11 +670,12 @@ def main(): warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, cycles=args.lr_cycles, + end_lr=1e3, train_epochs=args.num_train_epochs, warmup_epochs=args.lr_warmup_epochs, ) - trainer( + metrics = trainer( project="textual_inversion", train_dataloader=datamodule.train_dataloader, val_dataloader=datamodule.val_dataloader, @@ -672,6 +687,8 @@ def main(): placeholder_token_ids=placeholder_token_ids, ) + plot_metrics(metrics, output_dir.joinpath("lr.png")) + if args.simultaneous: run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) else: -- cgit v1.2.3-54-g00ecf