diff options
author | Volpeon <git@volpeon.ink> | 2022-11-27 16:57:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-27 16:57:58 +0100 |
commit | 2847ab264830ab78491a31c168b3a24bfcf66334 (patch) | |
tree | 424b877875f01401845aaa03e98e59cb98c7679b /models | |
parent | Update (diff) | |
download | textual-inversion-diff-2847ab264830ab78491a31c168b3a24bfcf66334.tar.gz textual-inversion-diff-2847ab264830ab78491a31c168b3a24bfcf66334.tar.bz2 textual-inversion-diff-2847ab264830ab78491a31c168b3a24bfcf66334.zip |
Make prompt processor compatible with any model
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/prompt.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 259ac44..6b6b7e9 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
@@ -27,5 +27,6 @@ class PromptProcessor(): | |||
27 | def get_embeddings(self, input_ids: torch.IntTensor): | 27 | def get_embeddings(self, input_ids: torch.IntTensor): |
28 | prompts = input_ids.shape[0] | 28 | prompts = input_ids.shape[0] |
29 | input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | 29 | input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) |
30 | text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768)) | 30 | text_embeddings = self.text_encoder(input_ids)[0] |
31 | text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) | ||
31 | return text_embeddings | 32 | return text_embeddings |