| author | Volpeon <git@volpeon.ink> | 2023-01-05 13:26:32 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 13:26:32 +0100 |
| commit | 3396ca881ed3f3521617cd9024eea56975191d32 | |
| tree | 3189c3bbe77b211152d11b524d0fe3a7016441ee /models/clip | |
| parent | Fix | |
Update
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/prompt.py | 10 |
1 file changed, 6 insertions(+), 4 deletions(-)
```diff
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9da3955..a7380be 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import torch
 
@@ -16,7 +16,7 @@ class PromptProcessor():
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: list[int]):
+    def unify_input_ids(self, input_ids: list[list[int]]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
@@ -24,13 +24,15 @@ class PromptProcessor():
             return_tensors="pt"
         )
 
-    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
+    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
         prompts = input_ids.shape[0]
 
         input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        if position_ids is not None:
+            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
             attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
-        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
+        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
```
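For reference, a minimal sketch of how the two updated signatures might be exercised. It assumes `PromptProcessor` is constructed from a Hugging Face CLIP tokenizer and text encoder (the patch only reveals the `self.tokenizer` and `self.text_encoder` attributes) and uses an example checkpoint name; neither detail comes from this patch. The constraints visible in the diff: `unify_input_ids` now takes one token-id list per prompt, and `get_embeddings` views its inputs as chunks of `tokenizer.model_max_length` tokens, so the padded width must be a multiple of that length. Restarting `position_ids` within each chunk is shown as one plausible use of the new optional argument, not as the author's confirmed intent.

```python
# Sketch only: PromptProcessor(tokenizer, text_encoder) is an assumed
# constructor, and the checkpoint name is an example.
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer

from models.clip.prompt import PromptProcessor

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
processor = PromptProcessor(tokenizer, text_encoder)  # assumed constructor

# unify_input_ids now expects list[list[int]]: one token-id list per prompt.
prompts = ["a photo of a cat", "a painting of a dog"]
ids = tokenizer(prompts, padding="do_not_pad").input_ids
batch = processor.unify_input_ids(ids)

# get_embeddings views input_ids as (-1, model_max_length), so pad the
# batch width up to a multiple of the model's max length before calling it.
max_len = tokenizer.model_max_length
width = batch.input_ids.shape[1]
pad_to = ((width + max_len - 1) // max_len) * max_len
input_ids = F.pad(batch.input_ids, (0, pad_to - width), value=tokenizer.pad_token_id)
attention_mask = F.pad(batch.attention_mask, (0, pad_to - width), value=0)

# position_ids is optional; omitting it keeps the encoder's default positions,
# so existing callers are unaffected. Here positions restart within each
# max_len-sized chunk (an illustrative choice, not taken from the patch).
n_chunks = pad_to // max_len
position_ids = torch.arange(max_len).repeat(len(prompts), n_chunks)

embeddings = processor.get_embeddings(
    input_ids, position_ids=position_ids, attention_mask=attention_mask
)
print(embeddings.shape)  # (num_prompts, pad_to, hidden_size)
```

Because `get_embeddings` re-stitches the encoder output to `(prompts, -1, hidden_size)`, a prompt longer than `model_max_length` simply comes back as a longer second dimension, one encoded chunk after another.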
