summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--models/clip/embeddings.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 840f8ae..4444cf9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
74 save_file({"embed": self.get_embed(input_ids)}, filename) 74 save_file({"embed": self.get_embed(input_ids)}, filename)
75 75
76 def persist(self): 76 def persist(self):
77 self.token_embedding.eval() 77 self.token_embedding.persist()
78 self.token_embedding.merged = False
79 78
80 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 79 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
81 if isinstance(input_ids, list): 80 if isinstance(input_ids, list):