summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-09 16:21:52 +0200
committerVolpeon <git@volpeon.ink>2023-04-09 16:21:52 +0200
commit776213e99da4ec389575e797d93de8d8960fa010 (patch)
treea21a76e32dbacb707c3d251c56e92d618d5e921b /models
parentFix (diff)
downloadtextual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.gz
textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.bz2
textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.zip
Update
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index dc4708a..9be8256 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
97 save_file({"embed": self.get_embed(input_ids)}, filename) 97 save_file({"embed": self.get_embed(input_ids)}, filename)
98 98
99 def persist(self): 99 def persist(self):
100 input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) 100 input_ids = torch.arange(
101 self.token_embedding.num_embeddings,
102 device=self.token_override_embedding.mapping.device
103 )
101 embs, mask = self.token_override_embedding(input_ids) 104 embs, mask = self.token_override_embedding(input_ids)
102 if embs is not None: 105 if embs is not None:
103 input_ids = input_ids[mask] 106 input_ids = input_ids[mask]