From 08cb1f476b8676e87eb42aafee1aa07e5b275e23 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 23 Apr 2023 16:19:54 +0200 Subject: Fix cycle loop --- train_ti.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 84ca296..d1e5467 100644 --- a/train_ti.py +++ b/train_ti.py @@ -937,17 +937,19 @@ def main(): lr_scheduler = "one_cycle" lr_warmup_epochs = args.lr_warmup_epochs lr_cycles = args.lr_cycles - if response.lower().strip() == "w": + elif response.lower().strip() == "w": lr_scheduler = "constant_with_warmup" lr_warmup_epochs = num_train_epochs - if response.lower().strip() == "c": + elif response.lower().strip() == "c": lr_scheduler = "constant" - if response.lower().strip() == "d": + elif response.lower().strip() == "d": lr_scheduler = "cosine" lr_warmup_epochs = 0 lr_cycles = 1 elif response.lower().strip() == "s": break + else: + continue print("") print(f"------------ TI cycle {training_iter + 1}: {response} ------------") -- cgit v1.2.3-54-g00ecf