diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-09 16:21:52 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-09 16:21:52 +0200 |
| commit | 776213e99da4ec389575e797d93de8d8960fa010 (patch) | |
| tree | a21a76e32dbacb707c3d251c56e92d618d5e921b /models | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.gz textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.bz2 textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.zip | |
Update
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index dc4708a..9be8256 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 97 | save_file({"embed": self.get_embed(input_ids)}, filename) | 97 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 98 | 98 | ||
| 99 | def persist(self): | 99 | def persist(self): |
| 100 | input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) | 100 | input_ids = torch.arange( |
| 101 | self.token_embedding.num_embeddings, | ||
| 102 | device=self.token_override_embedding.mapping.device | ||
| 103 | ) | ||
| 101 | embs, mask = self.token_override_embedding(input_ids) | 104 | embs, mask = self.token_override_embedding(input_ids) |
| 102 | if embs is not None: | 105 | if embs is not None: |
| 103 | input_ids = input_ids[mask] | 106 | input_ids = input_ids[mask] |
