From 0e4c36889aa6b7ec13320a03728118c7c1a8e716 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 27 Mar 2023 07:15:46 +0200 Subject: Sparse TI embeddings without sparse tensors --- models/clip/embeddings.py | 40 +++++++++++++++++++++++----------------- training/strategy/ti.py | 18 ++++++++---------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 870ee49..95904cf 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -42,20 +42,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.init_temp_embeddings() def init_temp_embeddings(self): - self.temp_token_embedding = nn.Embedding( - 0, - self.token_embedding.embedding_dim, - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype - ) + self.temp_token_embedding = nn.ParameterList() self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): - self.temp_token_embedding = resize_embedding( - self.temp_token_embedding, - size - self.num_permanent_embeddings, - self.initializer_factor - ) + for _ in range(len(self.temp_token_embedding), size): + self.temp_token_embedding.append(torch.zeros( + self.token_embedding.embedding_dim, + device=self.token_embedding.weight.device, + dtype=self.token_embedding.weight.dtype, + )) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): @@ -74,14 +70,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): with torch.no_grad(): initializer = self.get_embed(initializer) + initializer = initializer.to( + device=self.token_embedding.weight.device, + dtype=self.token_embedding.weight.dtype, + ) + token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) - self.temp_token_embedding.weight.data[mask] = initializer.to( - device=self.temp_token_embedding.weight.device, - dtype=self.temp_token_embedding.weight.dtype, - ) + + for i, id in enumerate(mask): + self.temp_token_embedding[id] = initializer[i] def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -91,7 +91,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] + for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): + self.token_embedding.weight.data[id] = emb self.num_permanent_embeddings = self.token_embedding.num_embeddings self.init_temp_embeddings() @@ -110,7 +111,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): all_temp_token_ids = all_temp_token_ids.unsqueeze(0) temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() - embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) + if len(temp_token_ids): + embeds_override = torch.stack([ + self.temp_token_embedding[id] + for id in temp_token_ids + ]) + embeds[embeds_mask] = embeds_override return embeds diff --git a/training/strategy/ti.py b/training/strategy/ti.py index b9a5547..7ac5011 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_before_optimize(lr: float, epoch: int): if use_emb_decay: - w = text_encoder.text_model.embeddings.temp_token_embedding.weight - return torch.all(w.grad == 0, dim=1) + return torch.stack([ + t + for t in text_encoder.text_model.embeddings.temp_token_embedding + if t.grad is not None + ]) @torch.no_grad() - def on_after_optimize(zero_ids, lr: float): + def on_after_optimize(w, lr: float): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) @@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks( lambda_ = emb_decay * lr if lambda_ != 0: - w = text_encoder.text_model.embeddings.temp_token_embedding.weight - - mask = torch.ones(w.shape[0], dtype=torch.bool) - mask[zero_ids] = False - - norm = w[mask, :].norm(dim=-1, keepdim=True) - w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + norm = w[:, :].norm(dim=-1, keepdim=True) + w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) def on_log(): if ema_embeddings is not None: -- cgit v1.2.3-70-g09d2