summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip')
-rw-r--r--  models/clip/util.py  6
1 file changed, 3 insertions, 3 deletions
diff --git a/models/clip/util.py b/models/clip/util.py
index 8de8c19..883de6a 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -23,11 +23,11 @@ def get_extended_embeddings(
23 model_max_length = text_encoder.config.max_position_embeddings 23 model_max_length = text_encoder.config.max_position_embeddings
24 prompts = input_ids.shape[0] 24 prompts = input_ids.shape[0]
25 25
26 input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) 26 input_ids = input_ids.view((-1, model_max_length))
27 if position_ids is not None: 27 if position_ids is not None:
28 position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) 28 position_ids = position_ids.view((-1, model_max_length))
29 if attention_mask is not None: 29 if attention_mask is not None:
30 attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) 30 attention_mask = attention_mask.view((-1, model_max_length))
31 31
32 text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] 32 text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
33 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) 33 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))