summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-23 16:19:54 +0200
committerVolpeon <git@volpeon.ink>2023-04-23 16:19:54 +0200
commit08cb1f476b8676e87eb42aafee1aa07e5b275e23 (patch)
tree1f2c78c1efeebb1d0ca9069a812b9a0cee99ed3b
parentUpdate (diff)
downloadtextual-inversion-diff-08cb1f476b8676e87eb42aafee1aa07e5b275e23.tar.gz
textual-inversion-diff-08cb1f476b8676e87eb42aafee1aa07e5b275e23.tar.bz2
textual-inversion-diff-08cb1f476b8676e87eb42aafee1aa07e5b275e23.zip
Fix cycle loop
-rw-r--r--train_lora.py8
-rw-r--r--train_ti.py8
2 files changed, 10 insertions, 6 deletions
diff --git a/train_lora.py b/train_lora.py
index 1d1485d..c197206 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1057,17 +1057,19 @@ def main():
1057 lr_scheduler = "one_cycle" 1057 lr_scheduler = "one_cycle"
1058 lr_warmup_epochs = args.lr_warmup_epochs 1058 lr_warmup_epochs = args.lr_warmup_epochs
1059 lr_cycles = args.lr_cycles 1059 lr_cycles = args.lr_cycles
1060 if response.lower().strip() == "w": 1060 elif response.lower().strip() == "w":
1061 lr_scheduler = "constant_with_warmup" 1061 lr_scheduler = "constant_with_warmup"
1062 lr_warmup_epochs = num_train_epochs 1062 lr_warmup_epochs = num_train_epochs
1063 if response.lower().strip() == "c": 1063 elif response.lower().strip() == "c":
1064 lr_scheduler = "constant" 1064 lr_scheduler = "constant"
1065 if response.lower().strip() == "d": 1065 elif response.lower().strip() == "d":
1066 lr_scheduler = "cosine" 1066 lr_scheduler = "cosine"
1067 lr_warmup_epochs = 0 1067 lr_warmup_epochs = 0
1068 lr_cycles = 1 1068 lr_cycles = 1
1069 elif response.lower().strip() == "s": 1069 elif response.lower().strip() == "s":
1070 break 1070 break
1071 else:
1072 continue
1071 1073
1072 print("") 1074 print("")
1073 print(f"============ LoRA cycle {training_iter + 1}: {response} ============") 1075 print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
diff --git a/train_ti.py b/train_ti.py
index 84ca296..d1e5467 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -937,17 +937,19 @@ def main():
937 lr_scheduler = "one_cycle" 937 lr_scheduler = "one_cycle"
938 lr_warmup_epochs = args.lr_warmup_epochs 938 lr_warmup_epochs = args.lr_warmup_epochs
939 lr_cycles = args.lr_cycles 939 lr_cycles = args.lr_cycles
940 if response.lower().strip() == "w": 940 elif response.lower().strip() == "w":
941 lr_scheduler = "constant_with_warmup" 941 lr_scheduler = "constant_with_warmup"
942 lr_warmup_epochs = num_train_epochs 942 lr_warmup_epochs = num_train_epochs
943 if response.lower().strip() == "c": 943 elif response.lower().strip() == "c":
944 lr_scheduler = "constant" 944 lr_scheduler = "constant"
945 if response.lower().strip() == "d": 945 elif response.lower().strip() == "d":
946 lr_scheduler = "cosine" 946 lr_scheduler = "cosine"
947 lr_warmup_epochs = 0 947 lr_warmup_epochs = 0
948 lr_cycles = 1 948 lr_cycles = 1
949 elif response.lower().strip() == "s": 949 elif response.lower().strip() == "s":
950 break 950 break
951 else:
952 continue
951 953
952 print("") 954 print("")
953 print(f"------------ TI cycle {training_iter + 1}: {response} ------------") 955 print(f"------------ TI cycle {training_iter + 1}: {response} ------------")