From 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 07:58:14 +0200
Subject: Update

---
 training/strategy/lora.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

(limited to 'training/strategy/lora.py')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
--
cgit v1.2.3-54-g00ecf
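
For context, a minimal sketch of the calling side after this change, assuming the optimizer was built with two labeled param groups, "0" for the LoRA weights and "emb" for the embeddings, matching the keys read in on_after_optimize; loss_fn, dataloader, and num_epochs are hypothetical placeholders, not names from the patch:

# A sketch, not the repository's actual training loop.
for epoch in range(num_epochs):
    for batch in dataloader:
        loss = loss_fn(batch)
        accelerator.backward(loss)

        # Clips gradients (unless pti_mode) and, when use_emb_decay is
        # enabled, returns the stacked embedding weights for decay.
        w = on_before_optimize(epoch)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Map each param group's label to its current learning rate;
        # on_after_optimize picks lrs["emb"], falling back to lrs["0"]
        # when the embedding LR is zero.
        lrs = {
            label: group["lr"]
            for label, group in zip(("0", "emb"), optimizer.param_groups)
        }
        on_after_optimize(w, lrs)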