From 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 12:58:54 +0100
Subject: Added multi-vector embeddings

---
 models/clip/embeddings.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++
 models/clip/prompt.py     |   6 +--
 models/clip/tokenizer.py  |  64 +++++++++++++++++++++++++++
 3 files changed, 176 insertions(+), 3 deletions(-)
 create mode 100644 models/clip/embeddings.py
 create mode 100644 models/clip/tokenizer.py

(limited to 'models')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
new file mode 100644
index 0000000..7d63ffb
--- /dev/null
+++ b/models/clip/embeddings.py
@@ -0,0 +1,109 @@
+from typing import Union, Optional
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+from transformers import CLIPTextModel
+from transformers.models.clip import CLIPTextConfig
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
+
+
+def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
+    old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
+
+    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
+    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding.weight.data.zero_()
+    new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
+
+    return new_embedding
+
+
+class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+        super().__init__(config)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
+
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+        self.temp_token_embedding.weight.data.zero_()
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+
+    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+
+        if initializer is not None:
+            if isinstance(initializer, int):
+                initializer = [initializer]
+
+            if isinstance(initializer, list):
+                initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+                with torch.no_grad():
+                    initializer = self.get_embed(initializer)
+
+        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
+        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
+
+        token_ids = torch.tensor(token_ids)
+
+        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+
+        if initializer is not None:
+            self.temp_token_embedding.weight.data[token_ids] = initializer
+        else:
+            self.temp_token_embedding.weight.data[token_ids].zero_()
+
+    def load_embed(self, input_ids: list[int], filename: Path):
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            self.add_embed(input_ids, file.get_tensor("embed"))
+
+    def save_embed(self, input_ids: list[int], filename: Path):
+        save_file({"embed": self.get_embed(input_ids)}, filename)
+
+    def make_permanent(self):
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+
+    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
+        if isinstance(input_ids, list):
+            input_ids = torch.tensor(input_ids)
+
+        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+
+        embeds = self.token_embedding(input_ids)
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+
+        return embeds
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embed(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+    text_encoder.text_model.embeddings = text_embeddings
+    return text_embeddings
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 
 import torch
 
@@ -10,13 +10,13 @@ class PromptProcessor():
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
 
-    def get_input_ids(self, prompt: Union[str, List[str]]):
+    def get_input_ids(self, prompt: Union[str, list[str]]):
         return self.tokenizer(
             prompt,
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: List[int]):
+    def unify_input_ids(self, input_ids: list[int]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
new file mode 100644
index 0000000..78871db
--- /dev/null
+++ b/models/clip/tokenizer.py
@@ -0,0 +1,64 @@
+import copy
+from typing import NamedTuple, Union
+
+import numpy as np
+
+from transformers import CLIPTokenizer
+
+
+class MultiCLIPTokenizerItem(NamedTuple):
+    token: str
+    placeholder_id: int
+    multi_ids: list[int]
+
+
+class MultiCLIPTokenizer(CLIPTokenizer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.token_map: dict[int, list[int]] = {}
+
+    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
+        if isinstance(new_tokens, list):
+            if isinstance(num_vectors, int):
+                num_vectors = [num_vectors] * len(new_tokens)
+
+            if len(num_vectors) != len(new_tokens):
+                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+
+            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+
+        if isinstance(num_vectors, list):
+            raise ValueError("Expected num_vectors to be int for single token")
+
+        super().add_tokens(new_tokens)
+
+        if num_vectors == 1:
+            multi_token = [new_tokens]
+        else:
+            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
+            super().add_tokens(multi_token)
+
+        meta_id = super().convert_tokens_to_ids(new_tokens)
+        multi_ids = super().convert_tokens_to_ids(multi_token)
+
+        self.token_map[meta_id] = multi_ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+
+    def encode(self, *args, vector_shuffle=True, **kwargs):
+        ids = super().encode(*args, **kwargs)
+        new_ids = []
+
+        for id in ids:
+            if id in self.token_map:
+                tokens = self.token_map[id]
+
+                if vector_shuffle:
+                    tokens = copy.copy(tokens)
+                    np.random.shuffle(tokens)
+
+                new_ids = new_ids + tokens
+            else:
+                new_ids.append(id)
+
+        return new_ids
--
cgit v1.2.3-54-g00ecf
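
A minimal usage sketch (not part of the commit) of how the two new modules could be wired together. The checkpoint name, the placeholder token "<token>", the initializer word "cat", and the prompt are illustrative placeholders; passing the placeholder id to add_embed alongside the sub-token ids is my reading of how the expansion needs to be sized (one new embedding row per id the tokenizer allocates), not something the commit itself demonstrates.

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings
    from models.clip.tokenizer import MultiCLIPTokenizer

    # Load a tokenizer/text-encoder pair (checkpoint name is a placeholder).
    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Swap the stock CLIPTextEmbeddings for the managed variant.
    embeddings = patch_managed_embeddings(text_encoder)

    # Register a placeholder token backed by 4 vectors.
    item = tokenizer.add_multi_tokens("<token>", num_vectors=4)

    # Initialize the new vectors from an existing word's embedding. The
    # placeholder id is included so the embedding tables grow by one row for
    # every id the tokenizer just allocated (meta token plus sub tokens).
    initializer_ids = tokenizer.encode("cat", add_special_tokens=False)
    embeddings.add_embed([item.placeholder_id] + item.multi_ids, initializer_ids)

    # encode() expands the placeholder into its sub-token ids (shuffled by
    # default); get_embed() routes those ids to the trainable temp embedding.
    input_ids = tokenizer.encode("a photo of <token>", vector_shuffle=True)
    token_embeds = embeddings.get_embed(input_ids)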