From 776213e99da4ec389575e797d93de8d8960fa010 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Apr 2023 16:21:52 +0200 Subject: Update --- models/clip/embeddings.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index dc4708a..9be8256 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) + input_ids = torch.arange( + self.token_embedding.num_embeddings, + device=self.token_override_embedding.mapping.device + ) embs, mask = self.token_override_embedding(input_ids) if embs is not None: input_ids = input_ids[mask] -- cgit v1.2.3-54-g00ecf