diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
| commit | 0767c7bc82645186159965c2a6be4278e33c6721 (patch) | |
| tree | a136470ab85dbb99ab51d9be4a7831fe21612ab3 /models/clip | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.gz textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.bz2 textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.zip | |
Update
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/util.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/util.py b/models/clip/util.py index 8de8c19..883de6a 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
| @@ -23,11 +23,11 @@ def get_extended_embeddings( | |||
| 23 | model_max_length = text_encoder.config.max_position_embeddings | 23 | model_max_length = text_encoder.config.max_position_embeddings |
| 24 | prompts = input_ids.shape[0] | 24 | prompts = input_ids.shape[0] |
| 25 | 25 | ||
| 26 | input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) | 26 | input_ids = input_ids.view((-1, model_max_length)) |
| 27 | if position_ids is not None: | 27 | if position_ids is not None: |
| 28 | position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) | 28 | position_ids = position_ids.view((-1, model_max_length)) |
| 29 | if attention_mask is not None: | 29 | if attention_mask is not None: |
| 30 | attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) | 30 | attention_mask = attention_mask.view((-1, model_max_length)) |
| 31 | 31 | ||
| 32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | 32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] |
| 33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | 33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) |
