diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-31 13:09:04 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-31 13:09:04 +0100 |
| commit | 8c068963d4b67c6b894e720288e5863dade8d6e6 (patch) | |
| tree | 823bf9852244e5adfe6a4f6fe5fcd87e8441e685 /models/clip | |
| parent | Added multi-vector embeddings (diff) | |
| download | textual-inversion-diff-8c068963d4b67c6b894e720288e5863dade8d6e6.tar.gz textual-inversion-diff-8c068963d4b67c6b894e720288e5863dade8d6e6.tar.bz2 textual-inversion-diff-8c068963d4b67c6b894e720288e5863dade8d6e6.zip | |
Fixes
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 7d63ffb..f82873e 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -74,7 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 74 | 74 | ||
| 75 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 75 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 76 | if isinstance(input_ids, list): | 76 | if isinstance(input_ids, list): |
| 77 | input_ids = torch.tensor(input_ids) | 77 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device) |
| 78 | 78 | ||
| 79 | mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) | 79 | mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) |
| 80 | 80 | ||
