From 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 24 Jun 2023 21:00:29 +0200
Subject: Update

---
 models/clip/embeddings.py |  4 ++--
 models/sparse.py          | 11 ++++++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

(limited to 'models')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8c3c6d4..afb7430 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def persist(self):
-        self.token_embedding.persist()
+    def persist(self, clear=False):
+        self.token_embedding.persist(clear)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
diff --git a/models/sparse.py b/models/sparse.py
index e5897c9..55c9837 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding):
 
         return weights
 
-    def persist(self):
+    def persist(self, clear=False):
         self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.trainable = nn.ParameterList()
+
+        if clear:
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+        else:
+            for param in self.trainable:
+                param.zero_()
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
--
cgit v1.2.3-54-g00ecf
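
Illustrative note (not part of the patch): the change above lets persist() either drop the sparse trainable state (clear=True, the old behaviour) or fold the learned deltas into the base weights and zero them so training can continue (clear=False, the new default). The sketch below demonstrates the two branches on a hypothetical stand-in class; ToySparseEmbedding, its constructor, and the demo calls are assumptions for illustration, not the repository's models.sparse.SparseEmbedding.

# Self-contained sketch of the persist(clear=...) semantics introduced above.
# ToySparseEmbedding is a hypothetical stand-in for models.sparse.SparseEmbedding;
# only the persist() branches mirror the diff, everything else is simplified.
import torch
import torch.nn as nn


class ToySparseEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)
        # For brevity, every id gets its own trainable delta (all zeros initially).
        self.register_buffer("trainable_ids", torch.arange(num_embeddings))
        self.trainable = nn.ParameterList(
            nn.Parameter(torch.zeros(embedding_dim)) for _ in range(num_embeddings)
        )

    def get_weights(self, ids: torch.Tensor) -> torch.Tensor:
        # One delta row per requested id.
        return torch.stack([self.trainable[i] for i in ids.tolist()])

    def persist(self, clear: bool = False):
        # Fold the learned deltas into the base embedding matrix.
        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))

        if clear:
            # Old behaviour: forget the trainable state entirely.
            self.trainable_ids[:] = -1
            self.trainable = nn.ParameterList()
        else:
            # New default: keep the parameters trainable, but zero them so the
            # deltas are not applied twice on top of the persisted weights.
            for param in self.trainable:
                param.zero_()


emb = ToySparseEmbedding(4, 8)
with torch.no_grad():
    emb.trainable[0] += 1.0   # pretend some training happened
    emb.persist(clear=False)  # deltas baked into weight, then zeroed; still trainable
    emb.persist(clear=True)   # nothing new to bake in; trainable state dropped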