| author | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
| commit | 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 | |
| tree | 19bd8802b6cfd941797beabfc0bb2595ffb00b5f /models | |
| parent | Fix TI | |
Update
Diffstat (limited to 'models')
| mode | file | changes |
|---|---|---|
| -rw-r--r-- | models/clip/embeddings.py | 2 |
| -rw-r--r-- | models/sparse.py | 2 |
2 files changed, 2 insertions, 2 deletions
```diff
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 63a141f..6fda33c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
```
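This change makes `persist()` allocate `input_ids` on the same device as the override embedding instead of defaulting to the CPU, avoiding a device-mismatch error when the model lives on a GPU. A minimal sketch of the before/after behavior (the standalone `embedding` module below is illustrative, not the repository's `ManagedCLIPTextEmbeddings`):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the patched module: an embedding that may sit on a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding = nn.Embedding(10, 4).to(device)

# Before the patch: torch.arange defaults to the CPU, so on a CUDA module the
# lookup would raise "Expected all tensors to be on the same device".
# input_ids = torch.arange(embedding.num_embeddings)

# After the patch: create the index tensor on the module's device up front.
input_ids = torch.arange(embedding.num_embeddings, device=embedding.weight.device)
print(embedding(input_ids).shape)  # torch.Size([10, 4])
```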
```diff
diff --git a/models/sparse.py b/models/sparse.py
index 8910316..d706db5 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
-        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
```
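This change turns `mapping` from a plain tensor attribute into a registered buffer. Buffers follow the module through `.to()` / `.cuda()` and are included in its `state_dict`, while a bare tensor attribute is skipped by both. A small sketch of that difference (the `WithBuffer` class is hypothetical):

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    """Hypothetical module contrasting a plain tensor attribute with a buffer."""
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(0, dtype=torch.long)  # invisible to .to() and state_dict()
        self.register_buffer("mapping", torch.zeros(0, dtype=torch.long))

module = WithBuffer()

# Buffers are serialized with the module; plain attributes are not.
print("mapping" in module.state_dict())  # True
print("plain" in module.state_dict())    # False

if torch.cuda.is_available():
    module.to("cuda")
    print(module.plain.device)    # cpu     -- left behind by .to()
    print(module.mapping.device)  # cuda:0  -- moved with the module
```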
