diff options
author | Volpeon <git@volpeon.ink> | 2023-06-24 21:00:29 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-24 21:00:29 +0200 |
commit | 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch) | |
tree | b0fcf8ad1d26c40d784ddc154622f6d01ecac082 /models/clip | |
parent | New loss scaling (diff) | |
download | textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.gz textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.bz2 textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.zip |
Update
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8c3c6d4..afb7430 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
79 | def save_embed(self, input_ids: list[int], filename: Path): | 79 | def save_embed(self, input_ids: list[int], filename: Path): |
80 | save_file({"embed": self.get_embed(input_ids)}, filename) | 80 | save_file({"embed": self.get_embed(input_ids)}, filename) |
81 | 81 | ||
82 | def persist(self): | 82 | def persist(self, clear=False): |
83 | self.token_embedding.persist() | 83 | self.token_embedding.persist(clear) |
84 | 84 | ||
85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
86 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |