summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py109
-rw-r--r--models/clip/prompt.py6
-rw-r--r--models/clip/tokenizer.py64
3 files changed, 176 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
new file mode 100644
index 0000000..7d63ffb
--- /dev/null
+++ b/models/clip/embeddings.py
@@ -0,0 +1,109 @@
1from typing import Union, Optional
2from pathlib import Path
3
4import torch
5import torch.nn as nn
6
7from safetensors import safe_open
8from safetensors.torch import save_file
9
10from transformers import CLIPTextModel
11from transformers.models.clip import CLIPTextConfig
12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
13
14
15def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
16 old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
17
18 new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
19 new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
20 new_embedding.weight.data.zero_()
21 new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
22
23 return new_embedding
24
25
26class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
27 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
28 super().__init__(config)
29
30 self.token_embedding = embeddings.token_embedding
31 self.position_embedding = embeddings.position_embedding
32
33 self.temp_token_embedding = nn.Embedding(
34 self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
35 self.temp_token_embedding.weight.data.zero_()
36 self.temp_token_ids = torch.tensor([])
37
38 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
39 if isinstance(token_ids, int):
40 token_ids = [token_ids]
41
42 if initializer is not None:
43 if isinstance(initializer, int):
44 initializer = [initializer]
45
46 if isinstance(initializer, list):
47 initializer = (initializer * len(token_ids))[:len(token_ids)]
48
49 with torch.no_grad():
50 initializer = self.get_embed(initializer)
51
52 self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
53 self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
54
55 token_ids = torch.tensor(token_ids)
56
57 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
58
59 if initializer is not None:
60 self.temp_token_embedding.weight.data[token_ids] = initializer
61 else:
62 self.temp_token_embedding.weight.data[token_ids].zero_()
63
64 def load_embed(self, input_ids: list[int], filename: Path):
65 with safe_open(filename, framework="pt", device="cpu") as file:
66 self.add_embed(input_ids, file.get_tensor("embed"))
67
68 def save_embed(self, input_ids: list[int], filename: Path):
69 save_file({"embed": self.get_embed(input_ids)}, filename)
70
71 def make_permanent(self):
72 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
73 self.temp_token_ids = torch.tensor([])
74
75 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
76 if isinstance(input_ids, list):
77 input_ids = torch.tensor(input_ids)
78
79 mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
80
81 embeds = self.token_embedding(input_ids)
82 embeds[mask] = self.temp_token_embedding(input_ids)[mask]
83
84 return embeds
85
86 def forward(
87 self,
88 input_ids: Optional[torch.LongTensor] = None,
89 position_ids: Optional[torch.LongTensor] = None,
90 inputs_embeds: Optional[torch.FloatTensor] = None,
91 ) -> torch.Tensor:
92 seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
93
94 if position_ids is None:
95 position_ids = self.position_ids[:, :seq_length]
96
97 if inputs_embeds is None:
98 inputs_embeds = self.get_embed(input_ids)
99
100 position_embeddings = self.position_embedding(position_ids)
101 embeddings = inputs_embeds + position_embeddings
102
103 return embeddings
104
105
106def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
107 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
108 text_encoder.text_model.embeddings = text_embeddings
109 return text_embeddings
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
1from typing import List, Union 1from typing import Union
2 2
3import torch 3import torch
4 4
@@ -10,13 +10,13 @@ class PromptProcessor():
10 self.tokenizer = tokenizer 10 self.tokenizer = tokenizer
11 self.text_encoder = text_encoder 11 self.text_encoder = text_encoder
12 12
13 def get_input_ids(self, prompt: Union[str, List[str]]): 13 def get_input_ids(self, prompt: Union[str, list[str]]):
14 return self.tokenizer( 14 return self.tokenizer(
15 prompt, 15 prompt,
16 padding="do_not_pad", 16 padding="do_not_pad",
17 ).input_ids 17 ).input_ids
18 18
19 def unify_input_ids(self, input_ids: List[int]): 19 def unify_input_ids(self, input_ids: list[int]):
20 return self.tokenizer.pad( 20 return self.tokenizer.pad(
21 {"input_ids": input_ids}, 21 {"input_ids": input_ids},
22 padding=True, 22 padding=True,
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
new file mode 100644
index 0000000..78871db
--- /dev/null
+++ b/models/clip/tokenizer.py
@@ -0,0 +1,64 @@
1import copy
2from typing import NamedTuple, Union
3
4import numpy as np
5
6from transformers import CLIPTokenizer
7
8
9class MultiCLIPTokenizerItem(NamedTuple):
10 token: str
11 placeholder_id: int
12 multi_ids: list[int]
13
14
15class MultiCLIPTokenizer(CLIPTokenizer):
16 def __init__(self, *args, **kwargs):
17 super().__init__(*args, **kwargs)
18 self.token_map: dict[int, list[int]] = {}
19
20 def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
21 if isinstance(new_tokens, list):
22 if isinstance(num_vectors, int):
23 num_vectors = [num_vectors] * len(new_tokens)
24
25 if len(num_vectors) != len(new_tokens):
26 raise ValueError("Expected new_tokens and num_vectors to have the same len")
27
28 return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
29
30 if isinstance(num_vectors, list):
31 raise ValueError("Expected num_vectors to be int for single token")
32
33 super().add_tokens(new_tokens)
34
35 if num_vectors == 1:
36 multi_token = [new_tokens]
37 else:
38 multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
39 super().add_tokens(multi_token)
40
41 meta_id = super().convert_tokens_to_ids(new_tokens)
42 multi_ids = super().convert_tokens_to_ids(multi_token)
43
44 self.token_map[meta_id] = multi_ids
45
46 return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
47
48 def encode(self, *args, vector_shuffle=True, **kwargs):
49 ids = super().encode(*args, **kwargs)
50 new_ids = []
51
52 for id in ids:
53 if id in self.token_map:
54 tokens = self.token_map[id]
55
56 if vector_shuffle:
57 tokens = copy.copy(tokens)
58 np.random.shuffle(tokens)
59
60 new_ids = new_ids + self.token_map[id]
61 else:
62 new_ids.append(id)
63
64 return new_ids