From d07364e55483e81603704a978c0050d58d357a77 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 17 Apr 2023 10:46:20 +0200 Subject: Fix --- train_lora.py | 5 ++--- train_ti.py | 6 +++--- training/strategy/ti.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/train_lora.py b/train_lora.py index ba5aee1..d0313fe 100644 --- a/train_lora.py +++ b/train_lora.py @@ -967,18 +967,17 @@ def main(): if len(auto_cycles) != 0: response = auto_cycles.pop(0) else: - response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") + response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") if response.lower().strip() == "o": lr_scheduler = "one_cycle" lr_warmup_epochs = args.lr_warmup_epochs lr_cycles = args.lr_cycles if response.lower().strip() == "w": - lr_scheduler = "constant" + lr_scheduler = "constant_with_warmup" lr_warmup_epochs = num_train_epochs if response.lower().strip() == "c": lr_scheduler = "constant" - lr_warmup_epochs = 0 if response.lower().strip() == "d": lr_scheduler = "cosine" lr_warmup_epochs = 0 diff --git a/train_ti.py b/train_ti.py index 880320f..b00b0d7 100644 --- a/train_ti.py +++ b/train_ti.py @@ -925,18 +925,18 @@ def main(): if len(auto_cycles) != 0: response = auto_cycles.pop(0) else: - response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") + response = input( + "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") if response.lower().strip() == "o": lr_scheduler = "one_cycle" lr_warmup_epochs = args.lr_warmup_epochs lr_cycles = args.lr_cycles if response.lower().strip() == "w": - lr_scheduler = "constant" + lr_scheduler = "constant_with_warmup" lr_warmup_epochs = num_train_epochs if response.lower().strip() == "c": lr_scheduler = "constant" - lr_warmup_epochs = 0 if response.lower().strip() == "d": lr_scheduler = "cosine" lr_warmup_epochs = 0 diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bbff64..f330cb7 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -119,7 +119,7 @@ def textual_inversion_strategy_callbacks( ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) if use_emb_decay and w is not None: - lr = lrs["emb"] or lrs["0"] + lr = lrs["emb"] if "emb" in lrs else lrs["0"] lambda_ = emb_decay * lr if lambda_ != 0: -- cgit v1.2.3-70-g09d2