| author | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
| commit | ee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch) | |
| tree | 20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /models/clip | |
| parent | Improved Textual Inversion: Completely exclude untrained embeddings from trai... (diff) | |
Fixed Textual Inversion
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/prompt.py | 6 |
|---|---|---|

1 file changed, 3 insertions(+), 3 deletions(-)
```diff
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9b427a0..da33ecf 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,10 +27,10 @@ class PromptProcessor():
     def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
         prompts = input_ids.shape[0]
 
-        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
-            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
         text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
+        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
```
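The change swaps `Tensor.reshape` for `Tensor.view` at the three points where token windows are flattened for the text encoder and then regrouped per prompt. `view` always returns a zero-copy view and raises an error if the tensor is not contiguous, whereas `reshape` may silently copy. The snippet below is a minimal, standalone sketch of the shape round-trip `get_embeddings` performs; the sizes (`model_max_length = 77`, the chunk count, the embedding width) and the random stand-in for the text encoder are illustrative assumptions, not values taken from the repository.

```python
import torch

# Illustrative sizes (assumptions): CLIP-style context length, 2 prompts,
# each pre-split into 3 windows of 77 tokens, 768-dim token embeddings.
model_max_length = 77
prompts, chunks, dim = 2, 3, 768

# input_ids arrives as (prompts, chunks * model_max_length): each long prompt
# was tokenized and concatenated into several fixed-length windows.
input_ids = torch.randint(0, 49408, (prompts, chunks * model_max_length))

# Flatten the prompt/chunk axes so the text encoder sees one window per row.
flat_ids = input_ids.view((-1, model_max_length))  # (prompts * chunks, 77)

# Stand-in for self.text_encoder(...)[0], which returns per-token embeddings.
flat_embeddings = torch.randn(flat_ids.shape[0], model_max_length, dim)

# Regroup the windows per prompt: (prompts, chunks * 77, dim).
text_embeddings = flat_embeddings.view((prompts, -1, flat_embeddings.shape[2]))
assert text_embeddings.shape == (prompts, chunks * model_max_length, dim)
```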
