Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py  2
-rw-r--r--  models/sparse.py           2
2 files changed, 2 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 63a141f..6fda33c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
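Note: the change above creates input_ids on the same device as the override embedding instead of defaulting to the CPU; indexing a CUDA-resident embedding with a CPU index tensor raises a device-mismatch RuntimeError. A minimal sketch of the failure mode (not part of the commit), using a plain nn.Embedding as a stand-in and illustrative names emb/ids:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
if torch.cuda.is_available():
    emb = emb.cuda()  # move the embedding weights to the GPU

# Before the fix: torch.arange defaults to the CPU, so emb(ids_cpu)
# would fail with a device-mismatch error whenever emb lives on the GPU.
ids_cpu = torch.arange(emb.num_embeddings)

# After the fix: the index tensor is created directly on the weight's device.
ids = torch.arange(emb.num_embeddings, device=emb.weight.device)
out = emb(ids)  # safe: index tensor and weights share a device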
diff --git a/models/sparse.py b/models/sparse.py
index 8910316..d706db5 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
-        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
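Note: registering mapping as a buffer, rather than assigning a plain tensor attribute, ties it to the module's lifecycle: Module.to() and .cuda() move it together with the parameters, and it is included in state_dict() so it survives save/load. A minimal sketch of the difference (the class names here are illustrative, not from the repository):

import torch
import torch.nn as nn

class Plain(nn.Module):
    def __init__(self):
        super().__init__()
        self.mapping = torch.zeros(0, dtype=torch.long)  # plain attribute

class Buffered(nn.Module):
    def __init__(self):
        super().__init__()
        # register_buffer makes the tensor part of the module's state
        self.register_buffer('mapping', torch.zeros(0, dtype=torch.long))

print('mapping' in Plain().state_dict())     # False: skipped by .to() and checkpoints
print('mapping' in Buffered().state_dict())  # True: moved by .to() and checkpointed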