diff options
author | Volpeon <git@volpeon.ink> | 2023-04-16 14:45:37 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-16 14:45:37 +0200 |
commit | 3924055ed24da9b6995303cd36282eb558ba0bf0 (patch) | |
tree | 4fed8dabcde2236e1a1e8f5738b2a0bdcfd4513b /models/sparse.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-3924055ed24da9b6995303cd36282eb558ba0bf0.tar.gz textual-inversion-diff-3924055ed24da9b6995303cd36282eb558ba0bf0.tar.bz2 textual-inversion-diff-3924055ed24da9b6995303cd36282eb558ba0bf0.zip |
Fix
Diffstat (limited to 'models/sparse.py')
-rw-r--r-- | models/sparse.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/models/sparse.py b/models/sparse.py new file mode 100644 index 0000000..bd45696 --- /dev/null +++ b/models/sparse.py | |||
@@ -0,0 +1,110 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | |||
6 | |||
class SparseEmbedding(nn.Embedding):
    """Frozen embedding table with trainable additive deltas for selected ids.

    The base ``weight`` matrix has ``requires_grad`` disabled. Ids passed to
    :meth:`mark_trainable` each receive their own zero-initialised parameter
    vector, stored in ``self.trainable``. ``forward`` returns the base
    embedding plus ``dropout(delta) * alpha`` for those ids and the plain base
    embedding for all others.

    The ``trainable_ids`` buffer maps every embedding id to the index of its
    delta inside ``self.trainable``, or ``-1`` when the id has no delta.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        alpha: int = 1,
        dropout: float = 0.0,
        **kwargs
    ):
        """Create the embedding.

        Args:
            num_embeddings: Number of rows in the embedding table.
            embedding_dim: Size of each embedding vector.
            alpha: Factor the deltas are multiplied with before being added.
            dropout: Dropout probability applied to the deltas; ``0`` disables
                dropout entirely.
            **kwargs: Forwarded unchanged to ``nn.Embedding.__init__``.
        """
        super().__init__(num_embeddings, embedding_dim, **kwargs)

        # -1 marks "no trainable delta for this id".
        self.register_buffer(
            'trainable_ids',
            torch.full(
                (num_embeddings,), -1,
                dtype=torch.long,
                device=self.weight.device,
            ),
        )

        self.trainable = nn.ParameterList()
        self.scaling = alpha
        self.dropout_p = dropout
        self.weight.requires_grad = False

        if dropout > 0.:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        # Kept from the original implementation: runs the base initialiser
        # once more now that ``trainable`` / ``trainable_ids`` exist.
        self.reset_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        """Return a copy of this embedding with ``new_num_embeddings`` rows.

        Rows that exist in both tables are copied. Additional rows are drawn
        from ``N(0, (initializer_factor * 0.02)^2)`` when ``initializer_factor``
        is given and are zero otherwise. The delta parameters are shared (not
        cloned) with the new module.
        """
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = SparseEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            alpha=self.scaling,
            dropout=self.dropout_p,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        for param in self.trainable:
            new_emb.trainable.append(param)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Attach a zero-initialised delta to every id that has none yet.

        Ids that are already trainable are left untouched. Duplicate ids in
        ``input_ids`` receive a single shared delta.
        """
        slots = self.trainable_ids[input_ids]
        # De-duplicate so that each id gets exactly one parameter and one
        # unambiguous slot index.
        new_ids = torch.unique(input_ids[slots == -1])

        count = new_ids.shape[0]
        if count == 0:
            return

        first = len(self.trainable)
        self.trainable_ids[new_ids] = torch.arange(
            first, first + count, device=self.trainable_ids.device
        )
        for _ in range(count):
            self.trainable.append(
                nn.Parameter(self.weight.new_zeros(self.embedding_dim))
            )

    def get_weights(self, input_ids: torch.Tensor):
        """Return the scaled deltas for ``input_ids``.

        The result has shape ``input_ids.shape + (embedding_dim,)``; positions
        whose id has no delta are zero. Dropout (if configured) is applied to
        the deltas before scaling.
        """
        original_shape = input_ids.shape
        flat_ids = input_ids.reshape(-1)

        weights = self.weight.new_zeros((flat_ids.shape[0], self.embedding_dim))

        slots = self.trainable_ids[flat_ids]
        mask = slots != -1
        elems = [self.trainable[slot] for slot in slots[mask].tolist()]

        if elems:
            w = self.dropout(torch.stack(elems)) * self.scaling
            weights[mask] = w.to(dtype=weights.dtype)

        return weights.view(*original_shape, self.embedding_dim)

    def persist(self):
        """Merge all deltas into the base weights and drop the deltas.

        Dropout is bypassed while merging so that the stored weights do not
        depend on a random mask when the module is in training mode.
        """
        all_ids = torch.arange(
            self.trainable_ids.shape[0], device=self.trainable_ids.device
        )

        dropout_was_training = self.dropout.training
        self.dropout.eval()
        try:
            with torch.no_grad():
                delta = self.get_weights(all_ids)
        finally:
            self.dropout.train(dropout_was_training)

        self.weight.data += delta
        self.trainable_ids[:] = -1
        self.trainable = nn.ParameterList()

    def reset_parameters(self):
        """Re-initialise the base weights and discard all deltas."""
        super().reset_parameters()
        # The base constructor calls this before ``trainable`` exists.
        if hasattr(self, "trainable"):
            self.trainable_ids[:] = -1
            self.trainable = nn.ParameterList()

    def train(self, mode: bool = True):
        """Set training mode on this module and its children; return ``self``."""
        # ``trainable`` and ``dropout`` are registered sub-modules, so the
        # base implementation already propagates the mode to them.
        return super().train(mode)

    def eval(self):
        """Switch to evaluation mode; return ``self``."""
        return self.train(False)

    def forward(self, input_ids: torch.LongTensor):
        """Look up the base embeddings and add the deltas of trainable ids."""
        result = super().forward(input_ids)
        result += self.get_weights(input_ids)
        return result