summaryrefslogtreecommitdiffstats
path: root/training/strategy/lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--training/strategy/lora.py12
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
79 tokenizer.eval() 79 tokenizer.eval()
80 yield 80 yield
81 81
82 def on_before_optimize(lr: float, epoch: int): 82 def on_before_optimize(epoch: int):
83 if not pti_mode: 83 if not pti_mode:
84 accelerator.clip_grad_norm_( 84 accelerator.clip_grad_norm_(
85 itertools.chain(unet.parameters(), text_encoder.parameters()), 85 itertools.chain(
86 unet.parameters(),
87 text_encoder.text_model.encoder.parameters(),
88 text_encoder.text_model.final_layer_norm.parameters(),
89 ),
86 max_grad_norm 90 max_grad_norm
87 ) 91 )
88 92
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
95 return torch.stack(params) if len(params) != 0 else None 99 return torch.stack(params) if len(params) != 0 else None
96 100
97 @torch.no_grad() 101 @torch.no_grad()
98 def on_after_optimize(w, lr: float): 102 def on_after_optimize(w, lrs: dict[str, float]):
103 lr = lrs["emb"] or lrs["0"]
104
99 if use_emb_decay and w is not None: 105 if use_emb_decay and w is not None:
100 lambda_ = emb_decay * lr 106 lambda_ = emb_decay * lr
101 107