diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-17 10:46:20 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-17 10:46:20 +0200 |
| commit | d07364e55483e81603704a978c0050d58d357a77 (patch) | |
| tree | 939e8f809f90142ef507fe09ef0a3f08c066a353 /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.tar.gz textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.tar.bz2 textual-inversion-diff-d07364e55483e81603704a978c0050d58d357a77.zip | |
Fix
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 880320f..b00b0d7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -925,18 +925,18 @@ def main(): | |||
| 925 | if len(auto_cycles) != 0: | 925 | if len(auto_cycles) != 0: |
| 926 | response = auto_cycles.pop(0) | 926 | response = auto_cycles.pop(0) |
| 927 | else: | 927 | else: |
| 928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | response = input( |
| 929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | ||
| 929 | 930 | ||
| 930 | if response.lower().strip() == "o": | 931 | if response.lower().strip() == "o": |
| 931 | lr_scheduler = "one_cycle" | 932 | lr_scheduler = "one_cycle" |
| 932 | lr_warmup_epochs = args.lr_warmup_epochs | 933 | lr_warmup_epochs = args.lr_warmup_epochs |
| 933 | lr_cycles = args.lr_cycles | 934 | lr_cycles = args.lr_cycles |
| 934 | if response.lower().strip() == "w": | 935 | if response.lower().strip() == "w": |
| 935 | lr_scheduler = "constant" | 936 | lr_scheduler = "constant_with_warmup" |
| 936 | lr_warmup_epochs = num_train_epochs | 937 | lr_warmup_epochs = num_train_epochs |
| 937 | if response.lower().strip() == "c": | 938 | if response.lower().strip() == "c": |
| 938 | lr_scheduler = "constant" | 939 | lr_scheduler = "constant" |
| 939 | lr_warmup_epochs = 0 | ||
| 940 | if response.lower().strip() == "d": | 940 | if response.lower().strip() == "d": |
| 941 | lr_scheduler = "cosine" | 941 | lr_scheduler = "cosine" |
| 942 | lr_warmup_epochs = 0 | 942 | lr_warmup_epochs = 0 |
