path: root/models/clip/embeddings.py
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--  models/clip/embeddings.py  |  6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index fb639f1..384c795 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def make_permanent(self):
+    def persist(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
         embeds = self.token_embedding(input_ids)
+
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
 
         return embeds
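
For readers skimming the diff, here is a minimal standalone sketch of the pattern these two hunks touch (plain PyTorch; the table sizes, token ids, and top-level variable names are illustrative, not the project's real configuration). It shows the lookup order the patched get_embed() uses, permanent table first and temporary rows overwritten via a torch.isin mask, and the copy-back performed by the renamed persist().

import torch
import torch.nn as nn

# Illustrative sizes; the real values come from the CLIP text model config.
vocab_size, embed_dim = 16, 4

token_embedding = nn.Embedding(vocab_size, embed_dim)       # permanent table
temp_token_embedding = nn.Embedding(vocab_size, embed_dim)  # trainable overlay

# Suppose ids 7 and 8 are temporary (newly added) tokens.
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)

input_ids = torch.tensor([1, 7, 3, 8], dtype=torch.long)

# Same order as the patched get_embed(): look everything up in the
# permanent table first, then overwrite the temporary ids with rows
# from the temp table.
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids.to(input_ids.device))
embeds[mask] = temp_token_embedding(input_ids)[mask]

# What persist() (formerly make_permanent()) does: copy the temp rows
# into the permanent table and clear the list of temporary ids.
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]
temp_token_ids = torch.tensor([], dtype=torch.long)

Note that moving the mask computation below the token_embedding lookup does not change behavior; the second hunk only reorders the statements so the mask is built next to the line that consumes it.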