From 08cb1f476b8676e87eb42aafee1aa07e5b275e23 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 23 Apr 2023 16:19:54 +0200 Subject: Fix cycle loop --- train_lora.py | 8 +++++--- train_ti.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/train_lora.py b/train_lora.py index 1d1485d..c197206 100644 --- a/train_lora.py +++ b/train_lora.py @@ -1057,17 +1057,19 @@ def main(): lr_scheduler = "one_cycle" lr_warmup_epochs = args.lr_warmup_epochs lr_cycles = args.lr_cycles - if response.lower().strip() == "w": + elif response.lower().strip() == "w": lr_scheduler = "constant_with_warmup" lr_warmup_epochs = num_train_epochs - if response.lower().strip() == "c": + elif response.lower().strip() == "c": lr_scheduler = "constant" - if response.lower().strip() == "d": + elif response.lower().strip() == "d": lr_scheduler = "cosine" lr_warmup_epochs = 0 lr_cycles = 1 elif response.lower().strip() == "s": break + else: + continue print("") print(f"============ LoRA cycle {training_iter + 1}: {response} ============") diff --git a/train_ti.py b/train_ti.py index 84ca296..d1e5467 100644 --- a/train_ti.py +++ b/train_ti.py @@ -937,17 +937,19 @@ def main(): lr_scheduler = "one_cycle" lr_warmup_epochs = args.lr_warmup_epochs lr_cycles = args.lr_cycles - if response.lower().strip() == "w": + elif response.lower().strip() == "w": lr_scheduler = "constant_with_warmup" lr_warmup_epochs = num_train_epochs - if response.lower().strip() == "c": + elif response.lower().strip() == "c": lr_scheduler = "constant" - if response.lower().strip() == "d": + elif response.lower().strip() == "d": lr_scheduler = "cosine" lr_warmup_epochs = 0 lr_cycles = 1 elif response.lower().strip() == "s": break + else: + continue print("") print(f"------------ TI cycle {training_iter + 1}: {response} ------------") -- cgit v1.2.3-70-g09d2