diff options
-rw-r--r-- | train_lora.py | 5 | ||||
-rw-r--r-- | train_ti.py | 6 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
3 files changed, 6 insertions, 7 deletions
diff --git a/train_lora.py b/train_lora.py index ba5aee1..d0313fe 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -967,18 +967,17 @@ def main(): | |||
967 | if len(auto_cycles) != 0: | 967 | if len(auto_cycles) != 0: |
968 | response = auto_cycles.pop(0) | 968 | response = auto_cycles.pop(0) |
969 | else: | 969 | else: |
970 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 970 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
971 | 971 | ||
972 | if response.lower().strip() == "o": | 972 | if response.lower().strip() == "o": |
973 | lr_scheduler = "one_cycle" | 973 | lr_scheduler = "one_cycle" |
974 | lr_warmup_epochs = args.lr_warmup_epochs | 974 | lr_warmup_epochs = args.lr_warmup_epochs |
975 | lr_cycles = args.lr_cycles | 975 | lr_cycles = args.lr_cycles |
976 | if response.lower().strip() == "w": | 976 | if response.lower().strip() == "w": |
977 | lr_scheduler = "constant" | 977 | lr_scheduler = "constant_with_warmup" |
978 | lr_warmup_epochs = num_train_epochs | 978 | lr_warmup_epochs = num_train_epochs |
979 | if response.lower().strip() == "c": | 979 | if response.lower().strip() == "c": |
980 | lr_scheduler = "constant" | 980 | lr_scheduler = "constant" |
981 | lr_warmup_epochs = 0 | ||
982 | if response.lower().strip() == "d": | 981 | if response.lower().strip() == "d": |
983 | lr_scheduler = "cosine" | 982 | lr_scheduler = "cosine" |
984 | lr_warmup_epochs = 0 | 983 | lr_warmup_epochs = 0 |
diff --git a/train_ti.py b/train_ti.py index 880320f..b00b0d7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -925,18 +925,18 @@ def main(): | |||
925 | if len(auto_cycles) != 0: | 925 | if len(auto_cycles) != 0: |
926 | response = auto_cycles.pop(0) | 926 | response = auto_cycles.pop(0) |
927 | else: | 927 | else: |
928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | response = input( |
929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | ||
929 | 930 | ||
930 | if response.lower().strip() == "o": | 931 | if response.lower().strip() == "o": |
931 | lr_scheduler = "one_cycle" | 932 | lr_scheduler = "one_cycle" |
932 | lr_warmup_epochs = args.lr_warmup_epochs | 933 | lr_warmup_epochs = args.lr_warmup_epochs |
933 | lr_cycles = args.lr_cycles | 934 | lr_cycles = args.lr_cycles |
934 | if response.lower().strip() == "w": | 935 | if response.lower().strip() == "w": |
935 | lr_scheduler = "constant" | 936 | lr_scheduler = "constant_with_warmup" |
936 | lr_warmup_epochs = num_train_epochs | 937 | lr_warmup_epochs = num_train_epochs |
937 | if response.lower().strip() == "c": | 938 | if response.lower().strip() == "c": |
938 | lr_scheduler = "constant" | 939 | lr_scheduler = "constant" |
939 | lr_warmup_epochs = 0 | ||
940 | if response.lower().strip() == "d": | 940 | if response.lower().strip() == "d": |
941 | lr_scheduler = "cosine" | 941 | lr_scheduler = "cosine" |
942 | lr_warmup_epochs = 0 | 942 | lr_warmup_epochs = 0 |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bbff64..f330cb7 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -119,7 +119,7 @@ def textual_inversion_strategy_callbacks( | |||
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
120 | 120 | ||
121 | if use_emb_decay and w is not None: | 121 | if use_emb_decay and w is not None: |
122 | lr = lrs["emb"] or lrs["0"] | 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] |
123 | lambda_ = emb_decay * lr | 123 | lambda_ = emb_decay * lr |
124 | 124 | ||
125 | if lambda_ != 0: | 125 | if lambda_ != 0: |