From 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 21:00:29 +0200 Subject: Update --- models/clip/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'models/clip/embeddings.py') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8c3c6d4..afb7430 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): def save_embed(self, input_ids: list[int], filename: Path): save_file({"embed": self.get_embed(input_ids)}, filename) - def persist(self): - self.token_embedding.persist() + def persist(self, clear=False): + self.token_embedding.persist(clear) def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): -- cgit v1.2.3-70-g09d2