diff options
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r-- | models/clip/embeddings.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9c3a56b..1280ebd 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
72 | 72 | ||
73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( |
75 | dtype=self.temp_token_embedding.weight.dtype) | 75 | device=self.temp_token_embedding.weight.device, |
76 | dtype=self.temp_token_embedding.weight.dtype, | ||
77 | ) | ||
76 | 78 | ||
77 | def load_embed(self, input_ids: list[int], filename: Path): | 79 | def load_embed(self, input_ids: list[int], filename: Path): |
78 | with safe_open(filename, framework="pt", device="cpu") as file: | 80 | with safe_open(filename, framework="pt", device="cpu") as file: |