summaryrefslogtreecommitdiffstats
path: root/models/clip/embeddings.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-26 16:34:31 +0200
committerVolpeon <git@volpeon.ink>2023-03-26 16:34:31 +0200
commite87fd8c55397db2bdf5177c42d2013037a9b9896 (patch)
tree264bd8b3d60c140d3e65abac0ed6d031403c4ccf /models/clip/embeddings.py
parentImproved TI embeddings (diff)
downloadtextual-inversion-diff-e87fd8c55397db2bdf5177c42d2013037a9b9896.tar.gz
textual-inversion-diff-e87fd8c55397db2bdf5177c42d2013037a9b9896.tar.bz2
textual-inversion-diff-e87fd8c55397db2bdf5177c42d2013037a9b9896.zip
Fix TI embeddings init
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--models/clip/embeddings.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8d01867..870ee49 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -77,7 +77,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
77 token_ids = torch.tensor(token_ids, dtype=torch.long) 77 token_ids = torch.tensor(token_ids, dtype=torch.long)
78 78
79 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 79 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
80 mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) 80 mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1)
81 self.temp_token_embedding.weight.data[mask] = initializer.to( 81 self.temp_token_embedding.weight.data[mask] = initializer.to(
82 device=self.temp_token_embedding.weight.device, 82 device=self.temp_token_embedding.weight.device,
83 dtype=self.temp_token_embedding.weight.dtype, 83 dtype=self.temp_token_embedding.weight.dtype,