From 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 07:58:14 +0200
Subject: Update

---
 training/strategy/dreambooth.py |  2 +-
 training/strategy/lora.py       | 12 +++++++++---
 training/strategy/ti.py         |  2 +-
 3 files changed, 11 insertions(+), 5 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
             yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
-- 
cgit v1.2.3-70-g09d2