summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-16 09:44:12 +0200
committerVolpeon <git@volpeon.ink>2023-04-16 09:44:12 +0200
commit1a0161f345191d78a19eec829f9d73b2c2c72f94 (patch)
tree6d7bcc67672ebf26454b3254b4bd9d5ec7e64a16 /models/clip
parentFix (diff)
downloadtextual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.gz
textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.bz2
textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.zip
Update
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 840f8ae..4444cf9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 save_file({"embed": self.get_embed(input_ids)}, filename) 74 save_file({"embed": self.get_embed(input_ids)}, filename)
75 75
76 def persist(self): 76 def persist(self):
77 self.token_embedding.eval() 77 self.token_embedding.persist()
78 self.token_embedding.merged = False
79 78
80 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 79 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
81 if isinstance(input_ids, list): 80 if isinstance(input_ids, list):