summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-22 21:15:24 +0100
committerVolpeon <git@volpeon.ink>2022-12-22 21:15:24 +0100
commitee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch)
tree20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /models/clip
parentImproved Textual Inversion: Completely exclude untrained embeddings from trai... (diff)
downloadtextual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.gz
textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.bz2
textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.zip
Fixed Textual Inversion
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/prompt.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9b427a0..da33ecf 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,10 +27,10 @@ class PromptProcessor():
27 def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): 27 def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
28 prompts = input_ids.shape[0] 28 prompts = input_ids.shape[0]
29 29
30 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 30 input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
31 if attention_mask is not None: 31 if attention_mask is not None:
32 attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 32 attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
33 33
34 text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] 34 text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
35 text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) 35 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
36 return text_embeddings 36 return text_embeddings