From d488f66c78e444d03c4ef8a957b82f8b239379d0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:31:24 +0200 Subject: Fix --- training/strategy/ti.py | 19 ------------------- 1 file changed, 19 deletions(-) (limited to 'training') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 49236c6..f0b84b5 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -103,29 +103,11 @@ def textual_inversion_strategy_callbacks( with ema_context(): yield - @torch.no_grad() - def on_before_optimize(epoch: int): - if use_emb_decay: - params = [ - p - for p in text_encoder.text_model.embeddings.token_embedding.parameters() - if p.grad is not None - ] - return torch.stack(params) if len(params) != 0 else None - @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) - if use_emb_decay and w is not None: - lr = lrs["emb"] if "emb" in lrs else lrs["0"] - lambda_ = emb_decay * lr - - if lambda_ != 0: - norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) - def on_log(): if ema_embeddings is not None: return {"ema_decay": ema_embeddings.decay} @@ -166,7 +148,6 @@ def textual_inversion_strategy_callbacks( return TrainingCallbacks( on_train=on_train, on_eval=on_eval, - on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2