From 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Apr 2023 17:38:49 +0200 Subject: Update --- models/clip/embeddings.py | 2 +- models/sparse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 63a141f..6fda33c 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - input_ids = torch.arange(self.token_embedding.num_embeddings) + input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) embs, mask = self.token_override_embedding(input_ids) if embs is not None: input_ids = input_ids[mask] diff --git a/models/sparse.py b/models/sparse.py index 8910316..d706db5 100644 --- a/models/sparse.py +++ b/models/sparse.py @@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module): self.embedding_dim = embedding_dim self.dtype = dtype self.params = nn.ParameterList() - self.mapping = torch.zeros(0, device=device, dtype=torch.long) + self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) def forward(self, input_ids: torch.LongTensor): ids = self.mapping[input_ids.to(self.mapping.device)] -- cgit v1.2.3-70-g09d2