diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8c3c6d4..afb7430 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 79 | def save_embed(self, input_ids: list[int], filename: Path): | 79 | def save_embed(self, input_ids: list[int], filename: Path): |
| 80 | save_file({"embed": self.get_embed(input_ids)}, filename) | 80 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 81 | 81 | ||
| 82 | def persist(self): | 82 | def persist(self, clear=False): |
| 83 | self.token_embedding.persist() | 83 | self.token_embedding.persist(clear) |
| 84 | 84 | ||
| 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 86 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |
