author    Volpeon <git@volpeon.ink>   2023-04-03 22:25:20 +0200
committer Volpeon <git@volpeon.ink>   2023-04-03 22:25:20 +0200
commit    2e654c017780d37f3304436e2feb84b619f1c023 (patch)
tree      8a248fe17c3512110de9fcfed7f7bfd708b3b8da /models/sparse.py
parent    TI: Delta learning (diff)
Improved sparse embeddings
Diffstat (limited to 'models/sparse.py')
-rw-r--r--  models/sparse.py  |  57
1 file changed, 57 insertions, 0 deletions
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..0b15454
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,57 @@
from typing import Optional

import torch
import torch.nn as nn

class PseudoSparseEmbedding(nn.Module):
    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dtype = dtype
        self.params = nn.ParameterList()
        # mapping[token_id] is the index into self.params, or -1 if the id has no embedding.
        self.mapping = torch.zeros(0, device=device, dtype=torch.long)

    def forward(self, input_ids: Optional[torch.LongTensor] = None):
        # No ids given: look up every id the mapping covers.
        if input_ids is None:
            input_ids = torch.arange(self.mapping.shape[0], device=self.mapping.device)

        ids = self.mapping[input_ids.to(self.mapping.device)]
        mask = ids != -1  # True where an embedding has been set for the id

        if not mask.any():
            embs = None
        else:
            embs = torch.stack([self.params[idx] for idx in ids[mask]])

        return embs, mask

    def resize(self, new_num_embeddings: int):
        old_num_embeddings = self.mapping.shape[0]
        n = min(old_num_embeddings, new_num_embeddings)

        # The new mapping starts out all -1 (no embedding); surviving entries are copied over.
        new_mapping = torch.full((new_num_embeddings,), -1, device=self.mapping.device, dtype=torch.long)
        new_mapping[:n] = self.mapping[:n]

        self.mapping = new_mapping

    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
        # Batched call: recurse once per id (paired with its tensor row, if given).
        if input_ids.ndim != 0:
            if tensor is not None:
                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
            else:
                return [self.set(id) for id in input_ids]

        idx = self.mapping[input_ids]

        if idx == -1:
            # First assignment to this id: allocate a slot and record it in the mapping.
            idx = len(self.params)
            self.mapping[input_ids] = idx
            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))

        self.params[idx] = tensor if tensor is not None else torch.zeros(
            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)

    def unset(self, input_ids: torch.LongTensor):
        # Drop the mapping for these ids; their slots in self.params are orphaned, not freed.
        self.mapping[input_ids] = -1
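
A minimal usage sketch of the class above (not part of the commit; the vocab size, embedding
dim, and token ids below are made up for illustration, and the import assumes you run from the
repository root): resize() sizes the id-to-slot mapping, set() assigns per-id vectors, and
forward() returns the stacked overridden embeddings plus a mask saying which input ids had one.

    import torch
    from models.sparse import PseudoSparseEmbedding

    emb = PseudoSparseEmbedding(embedding_dim=768)
    emb.resize(49408)                 # e.g. grow the mapping to a tokenizer's vocab size

    emb.set(torch.tensor(42))         # scalar form: zero-initialized slot for token 42
    emb.set(torch.tensor([7, 9]),     # batched form: one row per id
            torch.randn(2, 768))

    embs, mask = emb(torch.tensor([7, 42, 100]))
    print(mask)        # tensor([ True,  True, False]) - token 100 was never set
    print(embs.shape)  # torch.Size([2, 768]) - only the masked-in ids are stacked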