Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/lora.py  4
1 file changed, 2 insertions, 2 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
            params = [
                p
                for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
    @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
            lr = lrs["emb"]
            lambda_ = emb_decay * lr
 
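
Read together, the two hunks move the embedding-decay gate up front: on_before_optimize now only collects embedding weights when placeholder tokens are actually being trained, and since it returns nothing otherwise, the explicit use_emb_decay check in on_after_optimize becomes redundant and is dropped. Below is a minimal sketch of the resulting callback pair, assuming the surrounding lora_strategy_callbacks closure; everything past the guards visible in the diff (the return value and the decay update itself) is an assumption, since the file body after line 107 is truncated above.

import torch

def make_emb_decay_callbacks(text_encoder, placeholder_tokens,
                             use_emb_decay, emb_decay):
    # Sketch only: mirrors the guards visible in the diff; the return value
    # and the decay update are assumptions, not the file's exact code.
    def on_before_optimize(epoch: int):
        # New guard: skip entirely when no placeholder tokens are trained.
        if len(placeholder_tokens) != 0 and use_emb_decay:
            params = [
                p
                for p in text_encoder.text_model.embeddings.parameters()
                if p.grad is not None
            ]
            return params if len(params) != 0 else None

    @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
        # w is only non-None when on_before_optimize found decayable
        # embeddings, so re-checking use_emb_decay here is unnecessary.
        if w is not None and "emb" in lrs:
            lr = lrs["emb"]
            lambda_ = emb_decay * lr
            if lambda_ != 0:
                for p in w:
                    # Hypothetical decay step: pull each embedding's
                    # norm toward 1, scaled by lambda_.
                    norm = p.norm(dim=-1, keepdim=True)
                    p.add_((p / norm.clamp_min(1e-8)) * lambda_ * (1.0 - norm))

    return on_before_optimize, on_after_optimize

Because lambda_ = emb_decay * lr, the decay strength tracks the optimizer's embedding learning-rate schedule rather than being a fixed constant.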