Diffstat (limited to 'models/clip')
-rw-r--r--  models/clip/embeddings.py  29
-rw-r--r--  models/clip/tokenizer.py   23
-rw-r--r--  models/clip/util.py        17
3 files changed, 41 insertions, 28 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7c7f2ac..8c3c6d4 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        embeddings: CLIPTextEmbeddings,
+        alpha: int = 8,
+        dropout: float = 0.0,
+    ):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
-        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(
+            size, self.initializer_factor
+        )
 
     def add_embed(
         self,
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
-            initializer = (initializer * len(token_ids))[:len(token_ids)]
+            initializer = (initializer * len(token_ids))[: len(token_ids)]
 
         with torch.no_grad():
             initializer = self.get_embed(initializer)
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
+            input_ids = torch.tensor(
+                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
+            )
 
         return self.token_embedding(input_ids)
 
 
 def patch_managed_embeddings(
-    text_encoder: CLIPTextModel,
-    alpha: int = 8,
-    dropout: float = 0.0
+    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
         return text_encoder.text_model.embeddings
 
     text_embeddings = ManagedCLIPTextEmbeddings(
-        text_encoder.config,
-        text_encoder.text_model.embeddings,
-        alpha,
-        dropout
+        text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
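
The reformatting above does not change behavior; as a minimal usage sketch of the patched entry point (the checkpoint name is an assumed example, not something this diff prescribes):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)
    # Patching twice is a no-op: the isinstance guard above returns the
    # already-installed ManagedCLIPTextEmbeddings instance.
    assert patch_managed_embeddings(text_encoder) is embeddings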
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 789b525..a866641 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(
-        self,
-        new_tokens: Union[str, list[str]],
-        num_vectors: Union[int, list[int]] = 1
+        self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
     ) -> Union[list[int], list[list[int]]]:
         if isinstance(new_tokens, list):
             if isinstance(num_vectors, int):
                 num_vectors = [num_vectors] * len(new_tokens)
 
             if len(num_vectors) != len(new_tokens):
-                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+                raise ValueError(
+                    "Expected new_tokens and num_vectors to have the same len"
+                )
 
-            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+            return [
+                self.add_multi_tokens(new_token, vecs)
+                for new_token, vecs in zip(new_tokens, num_vectors)
+            ]
 
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return [id]
 
     def expand_ids(self, ids: list[int]):
-        return [
-            new_id
-            for id in ids
-            for new_id in self.expand_id(id)
-        ]
+        return [new_id for id in ids for new_id in self.expand_id(id)]
 
-    def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]):
+    def expand_batched_ids(
+        self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]
+    ):
         if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list):
             return [self.expand_ids(batch) for batch in input_ids]
         else:
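
For reference, a sketch of how the reformatted tokenizer API fits together; the checkpoint name and the placeholder token strings are assumptions for illustration:

    from models.clip.tokenizer import MultiCLIPTokenizer

    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # An int num_vectors is broadcast to every new token; a list must match
    # new_tokens in length, or the ValueError reformatted above is raised.
    ids = tokenizer.add_multi_tokens(["<skull>", "<rose>"], num_vectors=[4, 2])

    # expand_batched_ids accepts either one flat id list or a batch of lists.
    expanded = tokenizer.expand_batched_ids(ids)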
diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+def unify_input_ids(
+    tokenizer: CLIPTokenizer,
+    input_ids: list[list[int]],
+    max_length: Optional[int] = None,
+):
     if max_length is None:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
             pad_to_multiple_of=tokenizer.model_max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding="max_length",
             max_length=max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
 
+
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
     input_ids: torch.LongTensor,
     position_ids: Optional[torch.LongTensor] = None,
-    attention_mask=None
+    attention_mask=None,
 ):
     model_max_length = text_encoder.config.max_position_embeddings
     prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
     if attention_mask is not None:
         attention_mask = attention_mask.view((-1, model_max_length))
 
-    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_encoder(
+        input_ids, position_ids=position_ids, attention_mask=attention_mask
+    )[0]
     text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
     return text_embeddings
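
And a sketch combining the two reformatted helpers in util.py; the checkpoint name and prompts are assumed examples:

    from transformers import CLIPTextModel, CLIPTokenizer

    from models.clip.util import get_extended_embeddings, unify_input_ids

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    input_ids = tokenizer(["a photo of a cat", "a dog and a cat"]).input_ids
    # With max_length=None, prompts are padded to a multiple of
    # model_max_length so get_extended_embeddings can reshape them into
    # encoder-sized chunks before the text_encoder call.
    batch = unify_input_ids(tokenizer, input_ids)
    embeddings = get_extended_embeddings(
        text_encoder, batch.input_ids, attention_mask=batch.attention_mask
    )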