summary | refs | log | tree | commit | diff | stats
path: root/models/clip
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip')
-rw-r--r-- models/clip/embeddings.py 27
-rw-r--r-- models/clip/tokenizer.py 39
2 files changed, 44 insertions, 22 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f82873e..91a575d 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -15,8 +15,12 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
15def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding: 15def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
16 old_num_embeddings, old_embedding_dim = old_embedding.weight.size() 16 old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
17 17
18 new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim) 18 new_embedding = nn.Embedding(
19 new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype) 19 old_num_embeddings + n,
20 old_embedding_dim,
21 device=old_embedding.weight.device,
22 dtype=old_embedding.weight.dtype
23 )
20 new_embedding.weight.data.zero_() 24 new_embedding.weight.data.zero_()
21 new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data 25 new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
22 26
@@ -31,9 +35,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
31 self.position_embedding = embeddings.position_embedding 35 self.position_embedding = embeddings.position_embedding
32 36
33 self.temp_token_embedding = nn.Embedding( 37 self.temp_token_embedding = nn.Embedding(
34 self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) 38 self.token_embedding.num_embeddings,
39 self.token_embedding.embedding_dim,
40 device=self.token_embedding.weight.device,
41 dtype=self.token_embedding.weight.dtype
42 )
35 self.temp_token_embedding.weight.data.zero_() 43 self.temp_token_embedding.weight.data.zero_()
36 self.temp_token_ids = torch.tensor([]) 44 self.temp_token_ids = torch.tensor([], dtype=torch.long)
37 45
38 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 46 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
39 if isinstance(token_ids, int): 47 if isinstance(token_ids, int):
@@ -52,12 +60,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
52 self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids)) 60 self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
53 self.token_embedding = expand_embedding(self.token_embedding, len(token_ids)) 61 self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
54 62
55 token_ids = torch.tensor(token_ids) 63 token_ids = torch.tensor(token_ids, dtype=torch.long)
56 64
57 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 65 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
58 66
59 if initializer is not None: 67 if initializer is not None:
60 self.temp_token_embedding.weight.data[token_ids] = initializer 68 self.temp_token_embedding.weight.data[token_ids] = initializer.to(
69 dtype=self.temp_token_embedding.weight.dtype)
61 else: 70 else:
62 self.temp_token_embedding.weight.data[token_ids].zero_() 71 self.temp_token_embedding.weight.data[token_ids].zero_()
63 72
@@ -70,13 +79,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
70 79
71 def make_permanent(self): 80 def make_permanent(self):
72 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 81 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
73 self.temp_token_ids = torch.tensor([]) 82 self.temp_token_ids = torch.tensor([], dtype=torch.long)
74 83
75 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 84 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
76 if isinstance(input_ids, list): 85 if isinstance(input_ids, list):
77 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device) 86 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
78 87
79 mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) 88 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
80 89
81 embeds = self.token_embedding(input_ids) 90 embeds = self.token_embedding(input_ids)
82 embeds[mask] = self.temp_token_embedding(input_ids)[mask] 91 embeds[mask] = self.temp_token_embedding(input_ids)[mask]
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 7e08287..63566e0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer):
44 44
45 return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) 45 return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
46 46
47 def encode(self, *args, vector_shuffle=True, **kwargs): 47 def expand_id(self, id: int, vector_shuffle=True):
48 ids = super().encode(*args, **kwargs) 48 if id in self.token_map:
49 new_ids = [] 49 tokens = self.token_map[id]
50 50
51 for id in ids: 51 if vector_shuffle:
52 if id in self.token_map: 52 tokens = copy.copy(tokens)
53 tokens = self.token_map[id] 53 np.random.shuffle(tokens)
54 54
55 if vector_shuffle: 55 return tokens
56 tokens = copy.copy(tokens) 56 else:
57 np.random.shuffle(tokens) 57 return [id]
58 58
59 new_ids = new_ids + self.token_map[id] 59 def expand_ids(self, ids: list[int], vector_shuffle=True):
60 else: 60 return [
61 new_ids.append(id) 61 new_id
62 for id in ids
63 for new_id in self.expand_id(id, vector_shuffle)
64 ]
62 65
63 return new_ids 66 def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
67 result = super()._call_one(text, *args, **kwargs)
68
69 is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
70
71 if is_batched:
72 result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
73 else:
74 result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
75
76 return result