summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index dc4708a..9be8256 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
97 save_file({"embed": self.get_embed(input_ids)}, filename) 97 save_file({"embed": self.get_embed(input_ids)}, filename)
98 98
99 def persist(self): 99 def persist(self):
100 input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) 100 input_ids = torch.arange(
101 self.token_embedding.num_embeddings,
102 device=self.token_override_embedding.mapping.device
103 )
101 embs, mask = self.token_override_embedding(input_ids) 104 embs, mask = self.token_override_embedding(input_ids)
102 if embs is not None: 105 if embs is not None:
103 input_ids = input_ids[mask] 106 input_ids = input_ids[mask]