summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7d63ffb..f82873e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,7 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 74
75 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 75 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
76 if isinstance(input_ids, list): 76 if isinstance(input_ids, list):
77 input_ids = torch.tensor(input_ids) 77 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
78 78
79 mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) 79 mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
80 80