summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-23 11:07:57 +0100
committerVolpeon <git@volpeon.ink>2023-03-23 11:07:57 +0100
commit0767c7bc82645186159965c2a6be4278e33c6721 (patch)
treea136470ab85dbb99ab51d9be4a7831fe21612ab3 /models
parentFix (diff)
downloadtextual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.gz
textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.bz2
textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.zip
Update
Diffstat (limited to 'models')
-rw-r--r--models/clip/util.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/util.py b/models/clip/util.py
index 8de8c19..883de6a 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -23,11 +23,11 @@ def get_extended_embeddings(
23 model_max_length = text_encoder.config.max_position_embeddings 23 model_max_length = text_encoder.config.max_position_embeddings
24 prompts = input_ids.shape[0] 24 prompts = input_ids.shape[0]
25 25
26 input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) 26 input_ids = input_ids.view((-1, model_max_length))
27 if position_ids is not None: 27 if position_ids is not None:
28 position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) 28 position_ids = position_ids.view((-1, model_max_length))
29 if attention_mask is not None: 29 if attention_mask is not None:
30 attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) 30 attention_mask = attention_mask.view((-1, model_max_length))
31 31
32 text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] 32 text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
33 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) 33 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))