From a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 22:06:05 +0100 Subject: Update --- models/clip/embeddings.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'models/clip/embeddings.py') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9c3a56b..1280ebd 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) self.temp_token_embedding.weight.data[token_ids] = initializer.to( - dtype=self.temp_token_embedding.weight.dtype) + device=self.temp_token_embedding.weight.device, + dtype=self.temp_token_embedding.weight.dtype, + ) def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: -- cgit v1.2.3-70-g09d2