diff options
Diffstat (limited to 'training/strategy/lora.py')
| -rw-r--r-- | training/strategy/lora.py | 12 |
1 file changed, 9 insertions, 3 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 912ff26..89269c0 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -79,10 +79,14 @@ def lora_strategy_callbacks( | |||
| 79 | tokenizer.eval() | 79 | tokenizer.eval() |
| 80 | yield | 80 | yield |
| 81 | 81 | ||
| 82 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(epoch: int): |
| 83 | if not pti_mode: | 83 | if not pti_mode: |
| 84 | accelerator.clip_grad_norm_( | 84 | accelerator.clip_grad_norm_( |
| 85 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 85 | itertools.chain( |
| 86 | unet.parameters(), | ||
| 87 | text_encoder.text_model.encoder.parameters(), | ||
| 88 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 89 | ), | ||
| 86 | max_grad_norm | 90 | max_grad_norm |
| 87 | ) | 91 | ) |
| 88 | 92 | ||
| @@ -95,7 +99,9 @@ def lora_strategy_callbacks( | |||
| 95 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
| 96 | 100 | ||
| 97 | @torch.no_grad() | 101 | @torch.no_grad() |
| 98 | def on_after_optimize(w, lr: float): | 102 | def on_after_optimize(w, lrs: dict[str, float]): |
| 103 | lr = lrs["emb"] or lrs["0"] | ||
| 104 | |||
| 99 | if use_emb_decay and w is not None: | 105 | if use_emb_decay and w is not None: |
| 100 | lambda_ = emb_decay * lr | 106 | lambda_ = emb_decay * lr |
| 101 | 107 | ||
