diff options
author | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
commit | 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch) | |
tree | b1483a52fb853aecb7b73635cded3cce61edf125 /training/strategy/lora.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2 textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip |
Update
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 12 |
1 file changed, 9 insertions, 3 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 912ff26..89269c0 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -79,10 +79,14 @@ def lora_strategy_callbacks( | |||
79 | tokenizer.eval() | 79 | tokenizer.eval() |
80 | yield | 80 | yield |
81 | 81 | ||
82 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(epoch: int): |
83 | if not pti_mode: | 83 | if not pti_mode: |
84 | accelerator.clip_grad_norm_( | 84 | accelerator.clip_grad_norm_( |
85 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 85 | itertools.chain( |
86 | unet.parameters(), | ||
87 | text_encoder.text_model.encoder.parameters(), | ||
88 | text_encoder.text_model.final_layer_norm.parameters(), | ||
89 | ), | ||
86 | max_grad_norm | 90 | max_grad_norm |
87 | ) | 91 | ) |
88 | 92 | ||
@@ -95,7 +99,9 @@ def lora_strategy_callbacks( | |||
95 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
96 | 100 | ||
97 | @torch.no_grad() | 101 | @torch.no_grad() |
98 | def on_after_optimize(w, lr: float): | 102 | def on_after_optimize(w, lrs: dict[str, float]): |
103 | lr = lrs["emb"] or lrs["0"] | ||
104 | |||
99 | if use_emb_decay and w is not None: | 105 | if use_emb_decay and w is not None: |
100 | lambda_ = emb_decay * lr | 106 | lambda_ = emb_decay * lr |
101 | 107 | ||