diff options
author | Volpeon <git@volpeon.ink> | 2023-06-24 21:00:29 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-24 21:00:29 +0200 |
commit | 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch) | |
tree | b0fcf8ad1d26c40d784ddc154622f6d01ecac082 /models/sparse.py | |
parent | New loss scaling (diff) | |
download | textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.gz textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.bz2 textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.zip |
Update
Diffstat (limited to 'models/sparse.py')
-rw-r--r-- | models/sparse.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/models/sparse.py b/models/sparse.py index e5897c9..55c9837 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding): | |||
89 | 89 | ||
90 | return weights | 90 | return weights |
91 | 91 | ||
92 | def persist(self): | 92 | def persist(self, clear=False): |
93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
94 | self.trainable_ids[:] = -1 | 94 | |
95 | self.trainable = nn.ParameterList() | 95 | if clear: |
96 | self.trainable_ids[:] = -1 | ||
97 | self.trainable = nn.ParameterList() | ||
98 | else: | ||
99 | for param in self.trainable: | ||
100 | param.zero_() | ||
96 | 101 | ||
97 | def reset_parameters(self): | 102 | def reset_parameters(self): |
98 | nn.Embedding.reset_parameters(self) | 103 | nn.Embedding.reset_parameters(self) |