diff options
author | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
commit | 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 (patch) | |
tree | 19bd8802b6cfd941797beabfc0bb2595ffb00b5f /models/clip | |
parent | Fix TI (diff) | |
download | textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.gz textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.bz2 textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.zip |
Update
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 63a141f..6fda33c 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
96 | save_file({"embed": self.get_embed(input_ids)}, filename) | 96 | save_file({"embed": self.get_embed(input_ids)}, filename) |
97 | 97 | ||
98 | def persist(self): | 98 | def persist(self): |
99 | input_ids = torch.arange(self.token_embedding.num_embeddings) | 99 | input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) |
100 | embs, mask = self.token_override_embedding(input_ids) | 100 | embs, mask = self.token_override_embedding(input_ids) |
101 | if embs is not None: | 101 | if embs is not None: |
102 | input_ids = input_ids[mask] | 102 | input_ids = input_ids[mask] |