From bd951892e300a0e21cb0e10fe261cb647ca160cd Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Apr 2023 18:20:38 +0200
Subject: Added option to use constant LR on cycles > 1

---
 train_lora.py | 13 +++++++++++--
 train_ti.py   |  9 +++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 5c78664..4d4c16a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -94,6 +94,11 @@ def parse_args():
         default=1.0,
         help="Learning rate decay per cycle."
     )
+    parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
@@ -910,7 +915,6 @@ def main():
 
     create_lr_scheduler = partial(
         get_scheduler,
-        args.lr_scheduler,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -918,7 +922,6 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        warmup_epochs=args.lr_warmup_epochs,
         mid_point=args.lr_mid_point,
     )
 
@@ -971,6 +974,10 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
@@ -1005,10 +1012,12 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
+            args.lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
diff --git a/train_ti.py b/train_ti.py
index 45e730a..c452269 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -78,6 +78,11 @@ def parse_args():
         default=1.0,
         help="Learning rate decay per cycle."
     )
+    parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
@@ -926,6 +931,10 @@ def main():
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
-- 
cgit v1.2.3-70-g09d2
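
Note (not part of the patch): the scheduler name and warmup_epochs move out of the functools.partial and into the per-cycle call because values baked into a partial are captured once, when the partial is created; the later per-cycle reassignment of args.lr_scheduler and args.lr_warmup_epochs would otherwise never reach get_scheduler. Below is a minimal sketch of that behavior with a stand-in get_scheduler whose signature and the scheduler names are assumptions for illustration, not taken from the repo.

from functools import partial


def get_scheduler(name, warmup_epochs=0, **kwargs):
    # Stand-in for the real scheduler factory; hypothetical signature.
    return f"{name} scheduler with {warmup_epochs} warmup epochs"


class Args:
    # Illustrative values only.
    lr_scheduler = "one_cycle"
    lr_warmup_epochs = 10


args = Args()

# Values passed here are captured at partial-creation time.
frozen = partial(get_scheduler, args.lr_scheduler, warmup_epochs=args.lr_warmup_epochs)

# The training loop later switches to a constant LR for cycles > 1 ...
args.lr_scheduler = "constant"
args.lr_warmup_epochs = 0

# ... but the partial still uses the old values:
print(frozen())  # "one_cycle scheduler with 10 warmup epochs"

# Passing them at the call site, as the patched lora_lr_scheduler call does,
# picks up whatever the current cycle set:
print(get_scheduler(args.lr_scheduler, warmup_epochs=args.lr_warmup_epochs))
# "constant scheduler with 0 warmup epochs"

With the two arguments supplied at the call site, each cycle's scheduler sees the values chosen at the top of that cycle, which is what the --cycle_constant switch relies on.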