diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index fb639f1..384c795 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
88 | def save_embed(self, input_ids: list[int], filename: Path): | 88 | def save_embed(self, input_ids: list[int], filename: Path): |
89 | save_file({"embed": self.get_embed(input_ids)}, filename) | 89 | save_file({"embed": self.get_embed(input_ids)}, filename) |
90 | 90 | ||
91 | def make_permanent(self): | 91 | def persist(self): |
92 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 92 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
93 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 93 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
94 | 94 | ||
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
96 | if isinstance(input_ids, list): | 96 | if isinstance(input_ids, list): |
97 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 97 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
98 | 98 | ||
99 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | ||
100 | |||
101 | embeds = self.token_embedding(input_ids) | 99 | embeds = self.token_embedding(input_ids) |
100 | |||
101 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | ||
102 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 102 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] |
103 | 103 | ||
104 | return embeds | 104 | return embeds |