From 2e654c017780d37f3304436e2feb84b619f1c023 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Apr 2023 22:25:20 +0200
Subject: Improved sparse embeddings

---
 models/clip/embeddings.py | 52 +++++++++++++++++-------------------------
 models/sparse.py          | 57 +++++++++++++++++++++++++++++++++++++++++++++++
 train_ti.py               |  2 +-
 training/strategy/ti.py   |  8 +++----
 4 files changed, 83 insertions(+), 36 deletions(-)
 create mode 100644 models/sparse.py

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d8343a0..a356434 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,6 +11,8 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
+from models.sparse import PseudoSparseEmbedding
+
 
 def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
@@ -41,18 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.alpha = alpha
 
-        self.temp_token_embedding = nn.ParameterList()
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        self.token_override_embedding = PseudoSparseEmbedding(
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
+        )
+        self.alpha = alpha
 
     def resize(self, size: int):
-        for _ in range(len(self.temp_token_embedding), size):
-            self.temp_token_embedding.append(torch.zeros(
-                self.token_embedding.embedding_dim,
-                device=self.token_embedding.weight.device,
-                dtype=self.token_embedding.weight.dtype,
-            ))
+        self.token_override_embedding.resize(size)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -86,8 +86,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
-        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
         self.token_embedding.weight.data[token_ids] = initializer
+        self.token_override_embedding.set(token_ids)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,33 +97,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
-            self.token_embedding.weight.data[id] += self.alpha * emb
-            nn.init.zeros_(emb)
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        embs, mask = self.token_override_embedding(input_ids)
+        if embs is not None:
+            input_ids = input_ids[mask]
+            self.token_embedding.weight.data[input_ids] += self.alpha * embs
+            self.token_override_embedding.unset(input_ids)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
-        embeds = self.token_embedding(input_ids)
-        mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        if len(temp_token_ids):
-            embeds_override = torch.stack([
-                self.temp_token_embedding[id]
-                for id in temp_token_ids
-            ])
-            embeds[mask] += self.alpha * embeds_override
+        embs = self.token_embedding(input_ids)
+        embs_override, mask = self.token_override_embedding(input_ids)
+        if embs_override is not None:
+            embs[mask] += self.alpha * embs_override
 
-        return embeds
+        return embs
 
     def forward(
         self,
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..0b15454
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class PseudoSparseEmbedding(nn.Module):
+    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        self.dtype = dtype
+        self.params = nn.ParameterList()
+        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+
+    def forward(self, input_ids: Optional[torch.LongTensor] = None):
+        if input_ids is None:
+            input_ids = torch.arange(self.mapping.shape[0])
+
+        ids = self.mapping[input_ids.to(self.mapping.device)]
+        mask = ~(ids == -1)
+
+        if torch.all(~mask):
+            embs = None
+        else:
+            embs = torch.stack([self.params[id] for id in ids[mask]])
+
+        return embs, mask
+
+    def resize(self, new_num_embeddings: int):
+        old_num_embeddings = self.mapping.shape[0]
+        n = min(old_num_embeddings, new_num_embeddings)
+
+        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
+        new_mapping[:n] = self.mapping[:n]
+
+        self.mapping = new_mapping
+
+    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
+        if len(input_ids.shape) != 0:
+            if tensor is not None:
+                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
+            else:
+                return [self.set(id) for id in input_ids]
+
+        id = self.mapping[input_ids]
+
+        if id == -1:
+            id = len(self.params)
+            self.mapping[input_ids] = id
+            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
+
+        self.params[id] = tensor if tensor is not None else torch.zeros(
+            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+
+    def unset(self, input_ids: torch.LongTensor):
+        self.mapping[input_ids] = -1
diff --git a/train_ti.py b/train_ti.py
index 0ad7574..a9a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -809,7 +809,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
         lr=args.learning_rate,
     )
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 16baa34..95128da 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.token_override_embedding.params
 
     @contextmanager
     def on_train(epoch: int):
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
     def on_log():
         if ema_embeddings is not None:
-- 
cgit v1.2.3-54-g00ecf
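
For context (not part of the commit): a minimal, hypothetical usage sketch of the new PseudoSparseEmbedding module, based only on the interface introduced above (resize/set/forward/unset). The embedding size, token ids, and values below are made up for illustration.

import torch

from models.sparse import PseudoSparseEmbedding

# Sparse override table holding 4-dimensional vectors, with room for 10 token ids.
emb = PseudoSparseEmbedding(embedding_dim=4)
emb.resize(10)

# Register overrides for token ids 2 and 5 (one row of the tensor per id);
# calling set() without a tensor would initialize the overrides to zeros.
emb.set(torch.tensor([2, 5]), torch.randn(2, 4))

# forward() returns the stacked override vectors together with a boolean mask
# marking which of the queried ids actually have an override.
embs, mask = emb(torch.tensor([1, 2, 3, 5]))
print(mask)        # tensor([False,  True, False,  True])
print(embs.shape)  # torch.Size([2, 4])

# Drop the override for id 2 again; only id 5 stays mapped.
emb.unset(torch.tensor([2]))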