summaryrefslogtreecommitdiffstats
path: root/models/sparse.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-04 07:30:43 +0200
committerVolpeon <git@volpeon.ink>2023-04-04 07:30:43 +0200
commit30b557c8e1f03b4748ac3efca599ff51d66561cb (patch)
tree59aaacde83a7a44dc267c64455f6dc2cfb90c01f /models/sparse.py
parentImproved sparse embeddings (diff)
downloadtextual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.gz
textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.bz2
textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.zip
TI: Bring back old embedding decay
Diffstat (limited to 'models/sparse.py')
-rw-r--r--  models/sparse.py  14
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/models/sparse.py b/models/sparse.py
index 0b15454..8910316 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module):
13 self.params = nn.ParameterList() 13 self.params = nn.ParameterList()
14 self.mapping = torch.zeros(0, device=device, dtype=torch.long) 14 self.mapping = torch.zeros(0, device=device, dtype=torch.long)
15 15
16 def forward(self, input_ids: Optional[torch.LongTensor] = None): 16 def forward(self, input_ids: torch.LongTensor):
17 if input_ids is None:
18 input_ids = torch.arange(self.mapping.shape[0])
19
20 ids = self.mapping[input_ids.to(self.mapping.device)] 17 ids = self.mapping[input_ids.to(self.mapping.device)]
21 mask = ~(ids == -1) 18 mask = ~(ids == -1)
22 19
@@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module):
43 else: 40 else:
44 return [self.set(id) for id in input_ids] 41 return [self.set(id) for id in input_ids]
45 42
43 if tensor is None:
44 tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
45
46 if tensor.shape[-1] != self.embedding_dim:
47 raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
48
46 id = self.mapping[input_ids] 49 id = self.mapping[input_ids]
47 50
48 if id == -1: 51 if id == -1:
@@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module):
50 self.mapping[input_ids] = id 53 self.mapping[input_ids] = id
51 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) 54 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
52 55
53 self.params[id] = tensor if tensor is not None else torch.zeros( 56 self.params[id] = tensor
54 self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
55 57
56 def unset(self, input_ids: torch.LongTensor): 58 def unset(self, input_ids: torch.LongTensor):
57 self.mapping[input_ids] = -1 59 self.mapping[input_ids] = -1