diff options
author | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
commit | 358874cd2c49cb55676af86d2950b86d9ccb023a (patch) | |
tree | 786239d45944bab1f2af8e24165fc5d5054617f3 /models/clip | |
parent | Various updated; shuffle prompt content during training (diff) | |
download | textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2 textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip |
Support attention_mask of text encoder
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/prompt.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 6b6b7e9..9b427a0 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
@@ -22,11 +22,15 @@ class PromptProcessor(): | |||
22 | padding=True, | 22 | padding=True, |
23 | pad_to_multiple_of=self.tokenizer.model_max_length, | 23 | pad_to_multiple_of=self.tokenizer.model_max_length, |
24 | return_tensors="pt" | 24 | return_tensors="pt" |
25 | ).input_ids | 25 | ) |
26 | 26 | ||
27 | def get_embeddings(self, input_ids: torch.IntTensor): | 27 | def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): |
28 | prompts = input_ids.shape[0] | 28 | prompts = input_ids.shape[0] |
29 | |||
29 | input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | 30 | input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) |
30 | text_embeddings = self.text_encoder(input_ids)[0] | 31 | if attention_mask is not None: |
32 | attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
33 | |||
34 | text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] | ||
31 | text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) | 35 | text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) |
32 | return text_embeddings | 36 | return text_embeddings |