summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--models/clip/embeddings.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9c3a56b..1280ebd 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
72 72
73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
74 self.temp_token_embedding.weight.data[token_ids] = initializer.to( 74 self.temp_token_embedding.weight.data[token_ids] = initializer.to(
75 dtype=self.temp_token_embedding.weight.dtype) 75 device=self.temp_token_embedding.weight.device,
76 dtype=self.temp_token_embedding.weight.dtype,
77 )
76 78
77 def load_embed(self, input_ids: list[int], filename: Path): 79 def load_embed(self, input_ids: list[int], filename: Path):
78 with safe_open(filename, framework="pt", device="cpu") as file: 80 with safe_open(filename, framework="pt", device="cpu") as file: