diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/util.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/util.py b/models/clip/util.py index 8de8c19..883de6a 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
@@ -23,11 +23,11 @@ def get_extended_embeddings( | |||
23 | model_max_length = text_encoder.config.max_position_embeddings | 23 | model_max_length = text_encoder.config.max_position_embeddings |
24 | prompts = input_ids.shape[0] | 24 | prompts = input_ids.shape[0] |
25 | 25 | ||
26 | input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) | 26 | input_ids = input_ids.view((-1, model_max_length)) |
27 | if position_ids is not None: | 27 | if position_ids is not None: |
28 | position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) | 28 | position_ids = position_ids.view((-1, model_max_length)) |
29 | if attention_mask is not None: | 29 | if attention_mask is not None: |
30 | attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) | 30 | attention_mask = attention_mask.view((-1, model_max_length)) |
31 | 31 | ||
32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | 32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] |
33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | 33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) |