From 0e4c36889aa6b7ec13320a03728118c7c1a8e716 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 27 Mar 2023 07:15:46 +0200
Subject: Sparse TI embeddings without sparse tensors

---
 training/strategy/ti.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..7ac5011 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
+            return torch.stack([
+                t
+                for t in text_encoder.text_model.embeddings.temp_token_embedding
+                if t.grad is not None
+            ])
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks(
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_log():
         if ema_embeddings is not None:
--
cgit v1.2.3-70-g09d2
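
Note (not part of the patch): the decay step in on_after_optimize nudges each row of the stacked embedding tensor so that its L2 norm drifts toward emb_decay_target, at a rate scaled by emb_decay * lr. Below is a minimal standalone sketch of just that arithmetic. The function name, tensor shapes, and the learning-rate and decay values are hypothetical illustration choices, not the repository's code or defaults.

import torch

@torch.no_grad()
def decay_toward_target_norm(w: torch.Tensor, lr: float,
                             emb_decay: float, emb_decay_target: float) -> None:
    # Same arithmetic as the patched on_after_optimize: shift each row of `w`
    # along its own direction so its norm moves toward `emb_decay_target`.
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

# Hypothetical usage: four example embedding vectors of width 768.
w = torch.randn(4, 768)
print(w.norm(dim=-1))    # norms before the update
decay_toward_target_norm(w, lr=1e-2, emb_decay=10.0, emb_decay_target=0.4)
print(w.norm(dim=-1))    # each norm has moved ~10% of the way toward 0.4

With lambda_ = 1 the row norms would be set to the target exactly; smaller values move them only part of the way, which is why the update is scaled by the current learning rate.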