diff options
author | Volpeon <git@volpeon.ink> | 2023-04-17 10:46:20 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-17 10:46:20 +0200 |
commit | d07364e55483e81603704a978c0050d58d357a77 (patch) | |
tree | 939e8f809f90142ef507fe09ef0a3f08c066a353 /train_ti.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.tar.gz textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.tar.bz2 textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.zip |
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 880320f..b00b0d7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -925,18 +925,18 @@ def main(): | |||
925 | if len(auto_cycles) != 0: | 925 | if len(auto_cycles) != 0: |
926 | response = auto_cycles.pop(0) | 926 | response = auto_cycles.pop(0) |
927 | else: | 927 | else: |
928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | response = input( |
929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | ||
929 | 930 | ||
930 | if response.lower().strip() == "o": | 931 | if response.lower().strip() == "o": |
931 | lr_scheduler = "one_cycle" | 932 | lr_scheduler = "one_cycle" |
932 | lr_warmup_epochs = args.lr_warmup_epochs | 933 | lr_warmup_epochs = args.lr_warmup_epochs |
933 | lr_cycles = args.lr_cycles | 934 | lr_cycles = args.lr_cycles |
934 | if response.lower().strip() == "w": | 935 | if response.lower().strip() == "w": |
935 | lr_scheduler = "constant" | 936 | lr_scheduler = "constant_with_warmup" |
936 | lr_warmup_epochs = num_train_epochs | 937 | lr_warmup_epochs = num_train_epochs |
937 | if response.lower().strip() == "c": | 938 | if response.lower().strip() == "c": |
938 | lr_scheduler = "constant" | 939 | lr_scheduler = "constant" |
939 | lr_warmup_epochs = 0 | ||
940 | if response.lower().strip() == "d": | 940 | if response.lower().strip() == "d": |
941 | lr_scheduler = "cosine" | 941 | lr_scheduler = "cosine" |
942 | lr_warmup_epochs = 0 | 942 | lr_warmup_epochs = 0 |