summaryrefslogtreecommitdiffstats
path: root/models/sparse.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-24 21:00:29 +0200
committerVolpeon <git@volpeon.ink>2023-06-24 21:00:29 +0200
commit12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch)
treeb0fcf8ad1d26c40d784ddc154622f6d01ecac082 /models/sparse.py
parentNew loss scaling (diff)
downloadtextual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.gz
textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.bz2
textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.zip
Update
Diffstat (limited to 'models/sparse.py')
-rw-r--r--models/sparse.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/models/sparse.py b/models/sparse.py
index e5897c9..55c9837 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding):
89 89
90 return weights 90 return weights
91 91
92 def persist(self): 92 def persist(self, clear=False):
93 self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) 93 self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
94 self.trainable_ids[:] = -1 94
95 self.trainable = nn.ParameterList() 95 if clear:
96 self.trainable_ids[:] = -1
97 self.trainable = nn.ParameterList()
98 else:
99 for param in self.trainable:
100 param.zero_()
96 101
97 def reset_parameters(self): 102 def reset_parameters(self):
98 nn.Embedding.reset_parameters(self) 103 nn.Embedding.reset_parameters(self)