From 0767c7bc82645186159965c2a6be4278e33c6721 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 23 Mar 2023 11:07:57 +0100 Subject: Update --- models/clip/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'models/clip') diff --git a/models/clip/util.py b/models/clip/util.py index 8de8c19..883de6a 100644 --- a/models/clip/util.py +++ b/models/clip/util.py @@ -23,11 +23,11 @@ def get_extended_embeddings( model_max_length = text_encoder.config.max_position_embeddings prompts = input_ids.shape[0] - input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) + input_ids = input_ids.view((-1, model_max_length)) if position_ids is not None: - position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) + position_ids = position_ids.view((-1, model_max_length)) if attention_mask is not None: - attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) + attention_mask = attention_mask.view((-1, model_max_length)) text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) -- cgit v1.2.3-54-g00ecf