diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-16 09:44:12 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-16 09:44:12 +0200 |
| commit | 1a0161f345191d78a19eec829f9d73b2c2c72f94 (patch) | |
| tree | 6d7bcc67672ebf26454b3254b4bd9d5ec7e64a16 /models/clip | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.gz textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.bz2 textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.zip | |
Update
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 840f8ae..4444cf9 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 74 | save_file({"embed": self.get_embed(input_ids)}, filename) | 74 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 75 | 75 | ||
| 76 | def persist(self): | 76 | def persist(self): |
| 77 | self.token_embedding.eval() | 77 | self.token_embedding.persist() |
| 78 | self.token_embedding.merged = False | ||
| 79 | 78 | ||
| 80 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 79 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 81 | if isinstance(input_ids, list): | 80 | if isinstance(input_ids, list): |
