summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-10 13:12:37 +0100
committerVolpeon <git@volpeon.ink>2022-12-10 13:12:37 +0100
commit358874cd2c49cb55676af86d2950b86d9ccb023a (patch)
tree786239d45944bab1f2af8e24165fc5d5054617f3 /models
parentVarious updates; shuffle prompt content during training (diff)
downloadtextual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz
textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2
textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip
Support attention_mask of text encoder
Diffstat (limited to 'models')
-rw-r--r--models/clip/prompt.py10
1 file changed, 7 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 6b6b7e9..9b427a0 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -22,11 +22,15 @@ class PromptProcessor():
22 padding=True, 22 padding=True,
23 pad_to_multiple_of=self.tokenizer.model_max_length, 23 pad_to_multiple_of=self.tokenizer.model_max_length,
24 return_tensors="pt" 24 return_tensors="pt"
25 ).input_ids 25 )
26 26
27 def get_embeddings(self, input_ids: torch.IntTensor): 27 def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
28 prompts = input_ids.shape[0] 28 prompts = input_ids.shape[0]
29
29 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 30 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
30 text_embeddings = self.text_encoder(input_ids)[0] 31 if attention_mask is not None:
32 attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
33
34 text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
31 text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) 35 text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
32 return text_embeddings 36 return text_embeddings