diff options
author | Volpeon <git@volpeon.ink> | 2023-04-09 11:29:31 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-09 11:29:31 +0200 |
commit | ba9fd1a10746d85d2502c8a79ac49db63d346b04 (patch) | |
tree | 568bf65a0a4dcea2c34de4006b5761d0d6564307 /models/sparse.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.gz textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.bz2 textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.zip |
Update
Diffstat (limited to 'models/sparse.py')
-rw-r--r-- | models/sparse.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/models/sparse.py b/models/sparse.py index d706db5..bcb2897 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -5,22 +5,29 @@ import torch.nn as nn | |||
5 | 5 | ||
6 | 6 | ||
7 | class PseudoSparseEmbedding(nn.Module): | 7 | class PseudoSparseEmbedding(nn.Module): |
8 | def __init__(self, embedding_dim: int, device=None, dtype=torch.float32): | 8 | def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32): |
9 | super().__init__() | 9 | super().__init__() |
10 | 10 | ||
11 | self.embedding_dim = embedding_dim | 11 | self.embedding_dim = embedding_dim |
12 | self.dtype = dtype | 12 | self.dtype = dtype |
13 | self.params = nn.ParameterList() | 13 | self.params = nn.ParameterList() |
14 | |||
15 | if dropout_p > 0.0: | ||
16 | self.dropout = nn.Dropout(p=dropout_p) | ||
17 | else: | ||
18 | self.dropout = lambda x: x | ||
19 | |||
14 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) | 20 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) |
15 | 21 | ||
16 | def forward(self, input_ids: torch.LongTensor): | 22 | def forward(self, input_ids: torch.LongTensor): |
17 | ids = self.mapping[input_ids.to(self.mapping.device)] | 23 | input_ids = input_ids.to(self.mapping.device) |
24 | ids = self.mapping[input_ids] | ||
18 | mask = ~(ids == -1) | 25 | mask = ~(ids == -1) |
19 | 26 | ||
20 | if torch.all(~mask): | 27 | if torch.all(~mask): |
21 | embs = None | 28 | embs = None |
22 | else: | 29 | else: |
23 | embs = torch.stack([self.params[id] for id in ids[mask]]) | 30 | embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]])) |
24 | 31 | ||
25 | return embs, mask | 32 | return embs, mask |
26 | 33 | ||