diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 109 | ||||
| -rw-r--r-- | models/clip/prompt.py | 6 | ||||
| -rw-r--r-- | models/clip/tokenizer.py | 64 |
3 files changed, 176 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py new file mode 100644 index 0000000..7d63ffb --- /dev/null +++ b/models/clip/embeddings.py | |||
| @@ -0,0 +1,109 @@ | |||
| 1 | from typing import Union, Optional | ||
| 2 | from pathlib import Path | ||
| 3 | |||
| 4 | import torch | ||
| 5 | import torch.nn as nn | ||
| 6 | |||
| 7 | from safetensors import safe_open | ||
| 8 | from safetensors.torch import save_file | ||
| 9 | |||
| 10 | from transformers import CLIPTextModel | ||
| 11 | from transformers.models.clip import CLIPTextConfig | ||
| 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | ||
| 13 | |||
| 14 | |||
def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
    """Return a new embedding table with *n* extra rows appended.

    The existing rows are copied over unchanged; the appended rows are
    zero-initialized. Device and dtype follow the original table.
    """
    count, dim = old_embedding.weight.shape

    grown = nn.Embedding(count + n, dim)
    grown.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
    with torch.no_grad():
        grown.weight.zero_()
        grown.weight[:count].copy_(old_embedding.weight)

    return grown
| 24 | |||
| 25 | |||
class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings with a parallel "temporary" token table.

    Newly added placeholder tokens live in ``temp_token_embedding`` while
    they are being trained; ``get_embed``/``forward`` read them
    transparently, and ``make_permanent`` copies them into the regular
    ``token_embedding``.
    """

    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
        super().__init__(config)

        # Reuse the pretrained tables instead of the freshly initialized ones.
        self.token_embedding = embeddings.token_embedding
        self.position_embedding = embeddings.position_embedding

        self.temp_token_embedding = nn.Embedding(
            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
        self.temp_token_embedding.weight.data.zero_()
        # FIX: must be an integer tensor — these ids are used as an index in
        # make_permanent() and as the test set for torch.isin(); the original
        # torch.tensor([]) defaulted to float32 and broke both operations.
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
        """Grow both tables by ``len(token_ids)`` rows and seed the new rows.

        ``initializer`` may be an existing token id (or list of ids) whose
        current embedding seeds the new rows, a ready-made embedding tensor,
        or None for zero-initialization.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        if initializer is not None:
            if isinstance(initializer, int):
                initializer = [initializer]

            if isinstance(initializer, list):
                # Cycle/truncate the list so it matches len(token_ids).
                initializer = (initializer * len(token_ids))[:len(token_ids)]

                with torch.no_grad():
                    initializer = self.get_embed(initializer)

        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))

        # FIX: keep ids integral so torch.cat below preserves the long dtype.
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])

        if initializer is not None:
            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
                dtype=self.temp_token_embedding.weight.dtype)
        else:
            # FIX: advanced indexing returns a copy, so the original
            # `weight.data[token_ids].zero_()` zeroed a temporary and was a
            # no-op; assign into the table instead.
            self.temp_token_embedding.weight.data[token_ids] = 0

    def load_embed(self, input_ids: list[int], filename: Path):
        """Register embeddings previously persisted by save_embed()."""
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        """Persist the embeddings of ``input_ids`` as a safetensors file."""
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def make_permanent(self):
        """Copy the trained temporary rows into the permanent table."""
        self.token_embedding.weight.data[self.temp_token_ids] = \
            self.temp_token_embedding.weight.data[self.temp_token_ids]
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        """Look up embeddings, preferring temporary rows where they exist."""
        if isinstance(input_ids, list):
            input_ids = torch.tensor(input_ids, dtype=torch.long)

        # FIX: move the existing tensor to the right device instead of
        # re-wrapping it with torch.tensor(), which copies and warns.
        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))

        embeds = self.token_embedding(input_ids)
        embeds[mask] = self.temp_token_embedding(input_ids)[mask]

        return embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Standard CLIPTextEmbeddings forward, routed through get_embed()."""
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.get_embed(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
| 104 | |||
| 105 | |||
def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
    """Swap the encoder's stock embeddings for a managed wrapper.

    Returns the wrapper so the caller can add/persist placeholder tokens.
    """
    managed = ManagedCLIPTextEmbeddings(
        text_encoder.config, text_encoder.text_model.embeddings)
    text_encoder.text_model.embeddings = managed
    return managed
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | from typing import List, Union | 1 | from typing import Union |
| 2 | 2 | ||
| 3 | import torch | 3 | import torch |
| 4 | 4 | ||
| @@ -10,13 +10,13 @@ class PromptProcessor(): | |||
| 10 | self.tokenizer = tokenizer | 10 | self.tokenizer = tokenizer |
| 11 | self.text_encoder = text_encoder | 11 | self.text_encoder = text_encoder |
| 12 | 12 | ||
| 13 | def get_input_ids(self, prompt: Union[str, List[str]]): | 13 | def get_input_ids(self, prompt: Union[str, list[str]]): |
| 14 | return self.tokenizer( | 14 | return self.tokenizer( |
| 15 | prompt, | 15 | prompt, |
| 16 | padding="do_not_pad", | 16 | padding="do_not_pad", |
| 17 | ).input_ids | 17 | ).input_ids |
| 18 | 18 | ||
| 19 | def unify_input_ids(self, input_ids: List[int]): | 19 | def unify_input_ids(self, input_ids: list[int]): |
| 20 | return self.tokenizer.pad( | 20 | return self.tokenizer.pad( |
| 21 | {"input_ids": input_ids}, | 21 | {"input_ids": input_ids}, |
| 22 | padding=True, | 22 | padding=True, |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py new file mode 100644 index 0000000..78871db --- /dev/null +++ b/models/clip/tokenizer.py | |||
| @@ -0,0 +1,64 @@ | |||
| 1 | import copy | ||
| 2 | from typing import NamedTuple, Union | ||
| 3 | |||
| 4 | import numpy as np | ||
| 5 | |||
| 6 | from transformers import CLIPTokenizer | ||
| 7 | |||
| 8 | |||
class MultiCLIPTokenizerItem(NamedTuple):
    # Result record returned by MultiCLIPTokenizer.add_multi_tokens:
    token: str              # the placeholder token text as passed by the caller
    placeholder_id: int     # vocab id of the meta token registered for `token`
    multi_ids: list[int]    # vocab ids the meta token expands to during encode()
| 13 | |||
| 14 | |||
class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer that maps one placeholder token to several vector tokens.

    ``add_multi_tokens`` registers a meta token plus ``num_vectors`` sub-tokens;
    ``encode`` expands every meta token id into its sub-token ids, optionally
    shuffling their order per call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # meta token id -> list of sub-token ids it expands to
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
        """Register placeholder token(s); returns one item per token.

        Raises ValueError when the lengths of list arguments disagree, or when
        ``num_vectors`` is a list for a single token.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        super().add_tokens(new_tokens)

        if num_vectors == 1:
            # Single vector: the meta token is its own (only) sub-token.
            multi_token = [new_tokens]
        else:
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
        # Re-adding an already-known token is a no-op in HF tokenizers.
        super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle=True, **kwargs):
        """Encode text, expanding meta tokens into their sub-token ids.

        With ``vector_shuffle`` the sub-token order is randomized per call
        (np.random); the stored token_map is never mutated.
        """
        ids = super().encode(*args, **kwargs)
        expanded = []

        for token_id in ids:  # avoid shadowing builtin `id`
            if token_id in self.token_map:
                replacement = self.token_map[token_id]

                if vector_shuffle:
                    replacement = copy.copy(replacement)
                    np.random.shuffle(replacement)

                # BUG FIX: the original appended self.token_map[id] here,
                # discarding the shuffled copy — vector_shuffle had no effect.
                expanded.extend(replacement)
            else:
                expanded.append(token_id)

        return expanded
