diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 4 | ||||
-rw-r--r-- | models/sparse.py | 11 |
2 files changed, 10 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8c3c6d4..afb7430 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
79 | def save_embed(self, input_ids: list[int], filename: Path): | 79 | def save_embed(self, input_ids: list[int], filename: Path): |
80 | save_file({"embed": self.get_embed(input_ids)}, filename) | 80 | save_file({"embed": self.get_embed(input_ids)}, filename) |
81 | 81 | ||
82 | def persist(self): | 82 | def persist(self, clear=False): |
83 | self.token_embedding.persist() | 83 | self.token_embedding.persist(clear) |
84 | 84 | ||
85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
86 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |
diff --git a/models/sparse.py b/models/sparse.py index e5897c9..55c9837 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding): | |||
89 | 89 | ||
90 | return weights | 90 | return weights |
91 | 91 | ||
92 | def persist(self): | 92 | def persist(self, clear=False): |
93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
94 | self.trainable_ids[:] = -1 | 94 | |
95 | self.trainable = nn.ParameterList() | 95 | if clear: |
96 | self.trainable_ids[:] = -1 | ||
97 | self.trainable = nn.ParameterList() | ||
98 | else: | ||
99 | for param in self.trainable: | ||
100 | param.zero_() | ||
96 | 101 | ||
97 | def reset_parameters(self): | 102 | def reset_parameters(self): |
98 | nn.Embedding.reset_parameters(self) | 103 | nn.Embedding.reset_parameters(self) |