summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py33
1 files changed, 9 insertions, 24 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index e8cc865..4166dc6 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
38 self.token_embedding = embeddings.token_embedding 38 self.token_embedding = embeddings.token_embedding
39 self.position_embedding = embeddings.position_embedding 39 self.position_embedding = embeddings.position_embedding
40 self.initializer_factor = config.initializer_factor 40 self.initializer_factor = config.initializer_factor
41 self.init_temp_embeddings()
42 41
43 def init_temp_embeddings(self):
44 self.temp_token_embedding = nn.Embedding( 42 self.temp_token_embedding = nn.Embedding(
45 0, 43 self.token_embedding.num_embeddings,
46 self.token_embedding.embedding_dim, 44 self.token_embedding.embedding_dim,
47 device=self.token_embedding.weight.device, 45 device=self.token_embedding.weight.device,
48 dtype=self.token_embedding.weight.dtype 46 dtype=self.token_embedding.weight.dtype
49 ) 47 )
48 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
50 self.temp_token_ids = torch.tensor([], dtype=torch.long) 49 self.temp_token_ids = torch.tensor([], dtype=torch.long)
51 50
52 def resize(self, size: int): 51 def resize(self, size: int):
52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 54
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -74,16 +74,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 ) 74 )
75 75
76 token_ids = torch.tensor(token_ids, dtype=torch.long) 76 token_ids = torch.tensor(token_ids, dtype=torch.long)
77 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
78
79 self.temp_token_embedding = resize_embedding(
80 self.temp_token_embedding,
81 self.temp_token_ids.shape[0],
82 self.initializer_factor
83 )
84 77
85 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) 78 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
86 self.temp_token_embedding.weight.data[mask] = initializer 79 self.temp_token_embedding.weight.data[token_ids] = initializer
87 self.token_embedding.weight.data[token_ids] = initializer 80 self.token_embedding.weight.data[token_ids] = initializer
88 81
89 def load_embed(self, input_ids: list[int], filename: Path): 82 def load_embed(self, input_ids: list[int], filename: Path):
@@ -94,25 +87,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
94 save_file({"embed": self.get_embed(input_ids)}, filename) 87 save_file({"embed": self.get_embed(input_ids)}, filename)
95 88
96 def persist(self): 89 def persist(self):
97 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:] 90 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
98 self.init_temp_embeddings() 91 self.temp_token_ids = torch.tensor([], dtype=torch.long)
99 92
100 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 93 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
101 if isinstance(input_ids, list): 94 if isinstance(input_ids, list):
102 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 95 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
103 96
104 all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
105
106 embeds = self.token_embedding(input_ids) 97 embeds = self.token_embedding(input_ids)
107 98
108 embeds_mask = torch.isin(input_ids, all_temp_token_ids) 99 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
109 temp_token_ids = input_ids[embeds_mask] 100 embeds[mask] = self.temp_token_embedding(input_ids)[mask]
110
111 temp_token_ids = temp_token_ids.unsqueeze(1)
112 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
113 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
114
115 embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
116 101
117 return embeds 102 return embeds
118 103