diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-23 16:19:54 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-23 16:19:54 +0200 |
| commit | 08cb1f476b8676e87eb42aafee1aa07e5b275e23 (patch) | |
| tree | 1f2c78c1efeebb1d0ca9069a812b9a0cee99ed3b /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-08cb1f476b8676e87eb42aafee1aa07e5b275e23.tar.gz textual-inversion-diff-08cb1f476b8676e87eb42aafee1aa07e5b275e23.tar.bz2 textual-inversion-diff-08cb1f476b8676e87eb42aafee1aa07e5b275e23.zip | |
Fix cycle loop
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 84ca296..d1e5467 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -937,17 +937,19 @@ def main(): | |||
| 937 | lr_scheduler = "one_cycle" | 937 | lr_scheduler = "one_cycle" |
| 938 | lr_warmup_epochs = args.lr_warmup_epochs | 938 | lr_warmup_epochs = args.lr_warmup_epochs |
| 939 | lr_cycles = args.lr_cycles | 939 | lr_cycles = args.lr_cycles |
| 940 | if response.lower().strip() == "w": | 940 | elif response.lower().strip() == "w": |
| 941 | lr_scheduler = "constant_with_warmup" | 941 | lr_scheduler = "constant_with_warmup" |
| 942 | lr_warmup_epochs = num_train_epochs | 942 | lr_warmup_epochs = num_train_epochs |
| 943 | if response.lower().strip() == "c": | 943 | elif response.lower().strip() == "c": |
| 944 | lr_scheduler = "constant" | 944 | lr_scheduler = "constant" |
| 945 | if response.lower().strip() == "d": | 945 | elif response.lower().strip() == "d": |
| 946 | lr_scheduler = "cosine" | 946 | lr_scheduler = "cosine" |
| 947 | lr_warmup_epochs = 0 | 947 | lr_warmup_epochs = 0 |
| 948 | lr_cycles = 1 | 948 | lr_cycles = 1 |
| 949 | elif response.lower().strip() == "s": | 949 | elif response.lower().strip() == "s": |
| 950 | break | 950 | break |
| 951 | else: | ||
| 952 | continue | ||
| 951 | 953 | ||
| 952 | print("") | 954 | print("") |
| 953 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 955 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") |
