diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 109 | ||||
-rw-r--r-- | models/clip/prompt.py | 6 | ||||
-rw-r--r-- | models/clip/tokenizer.py | 64 |
3 files changed, 176 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py new file mode 100644 index 0000000..7d63ffb --- /dev/null +++ b/models/clip/embeddings.py | |||
@@ -0,0 +1,109 @@ | |||
1 | from typing import Union, Optional | ||
2 | from pathlib import Path | ||
3 | |||
4 | import torch | ||
5 | import torch.nn as nn | ||
6 | |||
7 | from safetensors import safe_open | ||
8 | from safetensors.torch import save_file | ||
9 | |||
10 | from transformers import CLIPTextModel | ||
11 | from transformers.models.clip import CLIPTextConfig | ||
12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | ||
13 | |||
14 | |||
15 | def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding: | ||
16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.size() | ||
17 | |||
18 | new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim) | ||
19 | new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype) | ||
20 | new_embedding.weight.data.zero_() | ||
21 | new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data | ||
22 | |||
23 | return new_embedding | ||
24 | |||
25 | |||
26 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | ||
27 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | ||
28 | super().__init__(config) | ||
29 | |||
30 | self.token_embedding = embeddings.token_embedding | ||
31 | self.position_embedding = embeddings.position_embedding | ||
32 | |||
33 | self.temp_token_embedding = nn.Embedding( | ||
34 | self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) | ||
35 | self.temp_token_embedding.weight.data.zero_() | ||
36 | self.temp_token_ids = torch.tensor([]) | ||
37 | |||
38 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | ||
39 | if isinstance(token_ids, int): | ||
40 | token_ids = [token_ids] | ||
41 | |||
42 | if initializer is not None: | ||
43 | if isinstance(initializer, int): | ||
44 | initializer = [initializer] | ||
45 | |||
46 | if isinstance(initializer, list): | ||
47 | initializer = (initializer * len(token_ids))[:len(token_ids)] | ||
48 | |||
49 | with torch.no_grad(): | ||
50 | initializer = self.get_embed(initializer) | ||
51 | |||
52 | self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids)) | ||
53 | self.token_embedding = expand_embedding(self.token_embedding, len(token_ids)) | ||
54 | |||
55 | token_ids = torch.tensor(token_ids) | ||
56 | |||
57 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | ||
58 | |||
59 | if initializer is not None: | ||
60 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
61 | else: | ||
62 | self.temp_token_embedding.weight.data[token_ids].zero_() | ||
63 | |||
64 | def load_embed(self, input_ids: list[int], filename: Path): | ||
65 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
66 | self.add_embed(input_ids, file.get_tensor("embed")) | ||
67 | |||
68 | def save_embed(self, input_ids: list[int], filename: Path): | ||
69 | save_file({"embed": self.get_embed(input_ids)}, filename) | ||
70 | |||
71 | def make_permanent(self): | ||
72 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | ||
73 | self.temp_token_ids = torch.tensor([]) | ||
74 | |||
75 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | ||
76 | if isinstance(input_ids, list): | ||
77 | input_ids = torch.tensor(input_ids) | ||
78 | |||
79 | mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) | ||
80 | |||
81 | embeds = self.token_embedding(input_ids) | ||
82 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | ||
83 | |||
84 | return embeds | ||
85 | |||
86 | def forward( | ||
87 | self, | ||
88 | input_ids: Optional[torch.LongTensor] = None, | ||
89 | position_ids: Optional[torch.LongTensor] = None, | ||
90 | inputs_embeds: Optional[torch.FloatTensor] = None, | ||
91 | ) -> torch.Tensor: | ||
92 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | ||
93 | |||
94 | if position_ids is None: | ||
95 | position_ids = self.position_ids[:, :seq_length] | ||
96 | |||
97 | if inputs_embeds is None: | ||
98 | inputs_embeds = self.get_embed(input_ids) | ||
99 | |||
100 | position_embeddings = self.position_embedding(position_ids) | ||
101 | embeddings = inputs_embeds + position_embeddings | ||
102 | |||
103 | return embeddings | ||
104 | |||
105 | |||
106 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | ||
107 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | ||
108 | text_encoder.text_model.embeddings = text_embeddings | ||
109 | return text_embeddings | ||
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
@@ -1,4 +1,4 @@ | |||
1 | from typing import List, Union | 1 | from typing import Union |
2 | 2 | ||
3 | import torch | 3 | import torch |
4 | 4 | ||
@@ -10,13 +10,13 @@ class PromptProcessor(): | |||
10 | self.tokenizer = tokenizer | 10 | self.tokenizer = tokenizer |
11 | self.text_encoder = text_encoder | 11 | self.text_encoder = text_encoder |
12 | 12 | ||
13 | def get_input_ids(self, prompt: Union[str, List[str]]): | 13 | def get_input_ids(self, prompt: Union[str, list[str]]): |
14 | return self.tokenizer( | 14 | return self.tokenizer( |
15 | prompt, | 15 | prompt, |
16 | padding="do_not_pad", | 16 | padding="do_not_pad", |
17 | ).input_ids | 17 | ).input_ids |
18 | 18 | ||
19 | def unify_input_ids(self, input_ids: List[int]): | 19 | def unify_input_ids(self, input_ids: list[int]): |
20 | return self.tokenizer.pad( | 20 | return self.tokenizer.pad( |
21 | {"input_ids": input_ids}, | 21 | {"input_ids": input_ids}, |
22 | padding=True, | 22 | padding=True, |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py new file mode 100644 index 0000000..78871db --- /dev/null +++ b/models/clip/tokenizer.py | |||
@@ -0,0 +1,64 @@ | |||
1 | import copy | ||
2 | from typing import NamedTuple, Union | ||
3 | |||
4 | import numpy as np | ||
5 | |||
6 | from transformers import CLIPTokenizer | ||
7 | |||
8 | |||
9 | class MultiCLIPTokenizerItem(NamedTuple): | ||
10 | token: str | ||
11 | placeholder_id: int | ||
12 | multi_ids: list[int] | ||
13 | |||
14 | |||
15 | class MultiCLIPTokenizer(CLIPTokenizer): | ||
16 | def __init__(self, *args, **kwargs): | ||
17 | super().__init__(*args, **kwargs) | ||
18 | self.token_map: dict[int, list[int]] = {} | ||
19 | |||
20 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: | ||
21 | if isinstance(new_tokens, list): | ||
22 | if isinstance(num_vectors, int): | ||
23 | num_vectors = [num_vectors] * len(new_tokens) | ||
24 | |||
25 | if len(num_vectors) != len(new_tokens): | ||
26 | raise ValueError("Expected new_tokens and num_vectors to have the same len") | ||
27 | |||
28 | return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] | ||
29 | |||
30 | if isinstance(num_vectors, list): | ||
31 | raise ValueError("Expected num_vectors to be int for single token") | ||
32 | |||
33 | super().add_tokens(new_tokens) | ||
34 | |||
35 | if num_vectors == 1: | ||
36 | multi_token = [new_tokens] | ||
37 | else: | ||
38 | multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)] | ||
39 | super().add_tokens(multi_token) | ||
40 | |||
41 | meta_id = super().convert_tokens_to_ids(new_tokens) | ||
42 | multi_ids = super().convert_tokens_to_ids(multi_token) | ||
43 | |||
44 | self.token_map[meta_id] = multi_ids | ||
45 | |||
46 | return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids) | ||
47 | |||
48 | def encode(self, *args, vector_shuffle=True, **kwargs): | ||
49 | ids = super().encode(*args, **kwargs) | ||
50 | new_ids = [] | ||
51 | |||
52 | for id in ids: | ||
53 | if id in self.token_map: | ||
54 | tokens = self.token_map[id] | ||
55 | |||
56 | if vector_shuffle: | ||
57 | tokens = copy.copy(tokens) | ||
58 | np.random.shuffle(tokens) | ||
59 | |||
60 | new_ids = new_ids + self.token_map[id] | ||
61 | else: | ||
62 | new_ids.append(id) | ||
63 | |||
64 | return new_ids | ||