summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-27 16:57:58 +0100
committerVolpeon <git@volpeon.ink>2022-11-27 16:57:58 +0100
commit2847ab264830ab78491a31c168b3a24bfcf66334 (patch)
tree424b877875f01401845aaa03e98e59cb98c7679b /models
parentUpdate (diff)
downloadtextual-inversion-diff-2847ab264830ab78491a31c168b3a24bfcf66334.tar.gz
textual-inversion-diff-2847ab264830ab78491a31c168b3a24bfcf66334.tar.bz2
textual-inversion-diff-2847ab264830ab78491a31c168b3a24bfcf66334.zip
Make prompt processor compatible with any model
Diffstat (limited to 'models')
-rw-r--r--models/clip/prompt.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 259ac44..6b6b7e9 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,5 +27,6 @@ class PromptProcessor():
27 def get_embeddings(self, input_ids: torch.IntTensor): 27 def get_embeddings(self, input_ids: torch.IntTensor):
28 prompts = input_ids.shape[0] 28 prompts = input_ids.shape[0]
29 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 29 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
30 text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768)) 30 text_embeddings = self.text_encoder(input_ids)[0]
31 text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
31 return text_embeddings 32 return text_embeddings