summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-17 10:46:20 +0200
committerVolpeon <git@volpeon.ink>2023-04-17 10:46:20 +0200
commitd07364e55483e81603704a978c0050d58d357a77 (patch)
tree939e8f809f90142ef507fe09ef0a3f08c066a353
parentFix (diff)
downloadtextual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.tar.gz
textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.tar.bz2
textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.zip
Fix
-rw-r--r--train_lora.py5
-rw-r--r--train_ti.py6
-rw-r--r--training/strategy/ti.py2
3 files changed, 6 insertions, 7 deletions
diff --git a/train_lora.py b/train_lora.py
index ba5aee1..d0313fe 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -967,18 +967,17 @@ def main():
967 if len(auto_cycles) != 0: 967 if len(auto_cycles) != 0:
968 response = auto_cycles.pop(0) 968 response = auto_cycles.pop(0)
969 else: 969 else:
970 response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 970 response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
971 971
972 if response.lower().strip() == "o": 972 if response.lower().strip() == "o":
973 lr_scheduler = "one_cycle" 973 lr_scheduler = "one_cycle"
974 lr_warmup_epochs = args.lr_warmup_epochs 974 lr_warmup_epochs = args.lr_warmup_epochs
975 lr_cycles = args.lr_cycles 975 lr_cycles = args.lr_cycles
976 if response.lower().strip() == "w": 976 if response.lower().strip() == "w":
977 lr_scheduler = "constant" 977 lr_scheduler = "constant_with_warmup"
978 lr_warmup_epochs = num_train_epochs 978 lr_warmup_epochs = num_train_epochs
979 if response.lower().strip() == "c": 979 if response.lower().strip() == "c":
980 lr_scheduler = "constant" 980 lr_scheduler = "constant"
981 lr_warmup_epochs = 0
982 if response.lower().strip() == "d": 981 if response.lower().strip() == "d":
983 lr_scheduler = "cosine" 982 lr_scheduler = "cosine"
984 lr_warmup_epochs = 0 983 lr_warmup_epochs = 0
diff --git a/train_ti.py b/train_ti.py
index 880320f..b00b0d7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -925,18 +925,18 @@ def main():
925 if len(auto_cycles) != 0: 925 if len(auto_cycles) != 0:
926 response = auto_cycles.pop(0) 926 response = auto_cycles.pop(0)
927 else: 927 else:
928 response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 928 response = input(
929 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
929 930
930 if response.lower().strip() == "o": 931 if response.lower().strip() == "o":
931 lr_scheduler = "one_cycle" 932 lr_scheduler = "one_cycle"
932 lr_warmup_epochs = args.lr_warmup_epochs 933 lr_warmup_epochs = args.lr_warmup_epochs
933 lr_cycles = args.lr_cycles 934 lr_cycles = args.lr_cycles
934 if response.lower().strip() == "w": 935 if response.lower().strip() == "w":
935 lr_scheduler = "constant" 936 lr_scheduler = "constant_with_warmup"
936 lr_warmup_epochs = num_train_epochs 937 lr_warmup_epochs = num_train_epochs
937 if response.lower().strip() == "c": 938 if response.lower().strip() == "c":
938 lr_scheduler = "constant" 939 lr_scheduler = "constant"
939 lr_warmup_epochs = 0
940 if response.lower().strip() == "d": 940 if response.lower().strip() == "d":
941 lr_scheduler = "cosine" 941 lr_scheduler = "cosine"
942 lr_warmup_epochs = 0 942 lr_warmup_epochs = 0
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bbff64..f330cb7 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -119,7 +119,7 @@ def textual_inversion_strategy_callbacks(
119 ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
120 120
121 if use_emb_decay and w is not None: 121 if use_emb_decay and w is not None:
122 lr = lrs["emb"] or lrs["0"] 122 lr = lrs["emb"] if "emb" in lrs else lrs["0"]
123 lambda_ = emb_decay * lr 123 lambda_ = emb_decay * lr
124 124
125 if lambda_ != 0: 125 if lambda_ != 0: