Diffstat (limited to 'training/strategy/lora.py')

 training/strategy/lora.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index d51a2f3..6730dc9 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -85,15 +85,16 @@ def lora_strategy_callbacks(
             )
 
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
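
The motivation for this change: torch.stack raises a RuntimeError ("stack expects a non-empty TensorList") when given an empty list, which happens whenever no override-embedding parameter received a gradient in a step. The patch returns None in that case and makes on_after_optimize skip the decay when its input is None. A minimal sketch of the pattern; the helper names below are illustrative, not from the repository:

import torch

def collect_grad_params(params):
    # Mirrors the patched on_before_optimize body: keep only parameters
    # that actually received a gradient. torch.stack([]) would raise
    # "RuntimeError: stack expects a non-empty TensorList", hence None.
    filtered = [p for p in params if p.grad is not None]
    return torch.stack(filtered) if len(filtered) != 0 else None

def apply_emb_decay(w, lr, emb_decay):
    # Mirrors the patched on_after_optimize guard: None from the
    # collection step means there is nothing to decay this step.
    if w is None:
        return
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        ...  # weight-decay update on w (elided in the hunk above)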