summaryrefslogtreecommitdiffstats
path: root/models/sparse.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-08 17:38:49 +0200
committerVolpeon <git@volpeon.ink>2023-04-08 17:38:49 +0200
commit9f5f70cb2a8919cb07821f264bf0fd75bfa10584 (patch)
tree19bd8802b6cfd941797beabfc0bb2595ffb00b5f /models/sparse.py
parentFix TI (diff)
downloadtextual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.gz
textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.bz2
textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.zip
Update
Diffstat (limited to 'models/sparse.py')
-rw-r--r--models/sparse.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/models/sparse.py b/models/sparse.py
index 8910316..d706db5 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module):
11 self.embedding_dim = embedding_dim 11 self.embedding_dim = embedding_dim
12 self.dtype = dtype 12 self.dtype = dtype
13 self.params = nn.ParameterList() 13 self.params = nn.ParameterList()
14 self.mapping = torch.zeros(0, device=device, dtype=torch.long) 14 self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
15 15
16 def forward(self, input_ids: torch.LongTensor): 16 def forward(self, input_ids: torch.LongTensor):
17 ids = self.mapping[input_ids.to(self.mapping.device)] 17 ids = self.mapping[input_ids.to(self.mapping.device)]