author    Volpeon <git@volpeon.ink>  2023-04-08 07:58:14 +0200
committer Volpeon <git@volpeon.ink>  2023-04-08 07:58:14 +0200
commit    5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch)
tree      b1483a52fb853aecb7b73635cded3cce61edf125 /training/strategy/lora.py
parent    Fix (diff)
Update
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--  training/strategy/lora.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
     tokenizer.eval()
     yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
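The first hunk narrows gradient clipping: text_encoder.parameters() is replaced by only the transformer encoder blocks and the final layer norm, so the token embedding matrix is no longer clipped. A minimal sketch of the resulting clipping call, assuming a Hugging Face CLIPTextModel (which exposes text_model.encoder and text_model.final_layer_norm) and an accelerate.Accelerator; the helper name is hypothetical:

import itertools

def clip_non_embedding_grads(accelerator, unet, text_encoder, max_grad_norm: float):
    # Mirrors the new on_before_optimize body: text_model.embeddings is
    # deliberately left out of the chain, so the learned token embeddings
    # are not clipped here (presumably because the embedding decay in
    # on_after_optimize regularizes them instead).
    accelerator.clip_grad_norm_(
        itertools.chain(
            unet.parameters(),
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
        ),
        max_grad_norm,
    )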
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
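The second hunk changes on_after_optimize to receive a dict of per-group learning rates and picks the embeddings group's rate, falling back to group "0". One subtlety of the `or` fallback, shown with assumed example values:

# "emb" is taken to be the embeddings param group and "0" the first
# positional group; the values here are made up for illustration.
lrs = {"emb": 0.0, "0": 1e-4}

# `or` falls back whenever lrs["emb"] is falsy, so an embeddings learning
# rate of exactly 0.0 also routes to lrs["0"], not just a None value
# (a missing "emb" key would raise KeyError rather than fall back).
lr = lrs["emb"] or lrs["0"]
assert lr == 1e-4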