From 71f4a40bb48be4f2759ba2d83faff39691cb2955 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Apr 2023 19:03:25 +0200
Subject: Improved automation caps

---
 training/strategy/ti.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

(limited to 'training/strategy/ti.py')

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f0b84b5..6bbff64 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -103,11 +103,29 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
+    @torch.no_grad()
+    def on_before_optimize(epoch: int):
+        if use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
 
+        if use_emb_decay and w is not None:
+            lr = lrs["emb"] or lrs["0"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks(
     )
 
     @torch.no_grad()
-    def on_sample(step):
+    def on_sample(cycle, step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
@@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks(
         unet_.to(dtype=weight_dtype)
         text_encoder_.to(dtype=weight_dtype)
 
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)
 
         unet_.to(dtype=orig_unet_dtype)
         text_encoder_.to(dtype=orig_text_encoder_dtype)
@@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks(
     return TrainingCallbacks(
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
--
cgit v1.2.3-54-g00ecf
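
Note: the new on_before_optimize/on_after_optimize pair above implements a learning-rate-scaled decay that pulls each trained embedding's L2 norm toward emb_decay_target. The following is a minimal, self-contained sketch of just that update, outside the callback machinery; the names emb_decay, emb_decay_target, and the stacked tensor w mirror the patch, while the concrete values and the 5x768 shape are made up for illustration.

import torch

# Hypothetical hyperparameters (in the repo these come from the
# textual inversion strategy's config; the values here are illustrative).
emb_decay = 1e-2          # decay strength
emb_decay_target = 0.4    # norm each embedding row is pulled toward
lr = 1e-3                 # learning rate of the embedding param group

# Stand-in for what on_before_optimize() returns: the token-embedding
# parameters that received a gradient this step, stacked into one tensor.
w = torch.randn(5, 768)

lambda_ = emb_decay * lr
if lambda_ != 0:
    # Per-row L2 norm, kept as a trailing singleton dim so it broadcasts.
    norm = w.norm(dim=-1, keepdim=True)
    # Move each row along its own direction by lambda_ * (target - norm):
    # rows longer than the target shrink, shorter rows grow; clamp_min
    # guards against division by a zero norm.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

Because lambda_ is computed from the live learning rate in lrs, the decay automatically follows any LR schedule, and the p.grad filter in on_before_optimize restricts the pull to embeddings that were actually updated in that step.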