summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-24 21:00:29 +0200
committerVolpeon <git@volpeon.ink>2023-06-24 21:00:29 +0200
commit12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch)
treeb0fcf8ad1d26c40d784ddc154622f6d01ecac082 /models/clip
parentNew loss scaling (diff)
downloadtextual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.gz
textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.bz2
textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.zip
Update
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8c3c6d4..afb7430 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
79 def save_embed(self, input_ids: list[int], filename: Path): 79 def save_embed(self, input_ids: list[int], filename: Path):
80 save_file({"embed": self.get_embed(input_ids)}, filename) 80 save_file({"embed": self.get_embed(input_ids)}, filename)
81 81
82 def persist(self): 82 def persist(self, clear=False):
83 self.token_embedding.persist() 83 self.token_embedding.persist(clear)
84 84
85 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 85 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
86 if isinstance(input_ids, list): 86 if isinstance(input_ids, list):