From 71f4a40bb48be4f2759ba2d83faff39691cb2955 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Apr 2023 19:03:25 +0200
Subject: Improved automation caps

---
 train_ti.py | 53 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 33 insertions(+), 20 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index c452269..880320f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -68,9 +68,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -78,11 +78,6 @@ def parse_args():
         default=1.0,
         help="Learning rate decay per cycle."
     )
-    parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
@@ -921,27 +916,45 @@ def main():
 
     sample_output_dir = output_dir / project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
         )
 
         lr_scheduler = get_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             optimizer=optimizer,
             num_training_steps_per_epoch=len(datamodule.train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -950,10 +963,10 @@ def main():
             annealing_func=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
+            cycles=lr_cycles,
             end_lr=1e3,
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            warmup_epochs=lr_warmup_epochs,
             mid_point=args.lr_mid_point,
         )
 
@@ -966,7 +979,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
--
cgit v1.2.3-54-g00ecf
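
In short, this commit replaces the integer --auto_cycles count with an action string: each character scripts the learning-rate schedule for one training cycle ([o] one_cycle, [w] constant with full-cycle warmup, [c] constant without warmup, [d] single cosine decay, [s] stop), and once the string is exhausted the loop falls back to an interactive prompt. The following is a minimal standalone sketch of that dispatch under those assumptions; the helper names (cycle_settings, run_cycles) are illustrative and do not appear in train_ti.py.

# Illustrative sketch of the --auto_cycles dispatch; not code from train_ti.py.
def cycle_settings(response, args, num_train_epochs, lr_cycles):
    """Map one action letter to (lr_scheduler, lr_warmup_epochs, lr_cycles)."""
    r = response.lower().strip()
    if r == "o":   # full one-cycle schedule with the configured warmup
        return "one_cycle", args.lr_warmup_epochs, args.lr_cycles
    if r == "w":   # constant LR that warms up over the whole cycle
        return "constant", num_train_epochs, lr_cycles
    if r == "c":   # constant LR, no warmup
        return "constant", 0, lr_cycles
    if r == "d":   # single cosine decay, no warmup
        return "cosine", 0, 1
    return None    # "s" (or anything unrecognized): stop training

def run_cycles(args, num_train_epochs):
    queue = list(args.auto_cycles)  # e.g. "owd": one_cycle, then warmup, then decay
    settings = (args.lr_scheduler, args.lr_warmup_epochs, args.lr_cycles)
    while True:
        # Consume scripted actions first, then fall back to the interactive prompt.
        response = queue.pop(0) if queue else input(
            "Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
        new = cycle_settings(response, args, num_train_epochs, settings[2])
        if new is None:
            break
        settings = new
        print(f"cycle: scheduler={settings[0]}, warmup_epochs={settings[1]}, cycles={settings[2]}")
        # ... create the optimizer and LR scheduler with these settings, then train one cycle ...

With this scheme, passing for example --auto_cycles "owd" would run a one_cycle pass, a constant warmup cycle, and a decay cycle before dropping to the prompt; the new default of "o" runs a single one_cycle pass and then asks what to do next.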