Diffstat (limited to 'models/clip/util.py')
-rw-r--r--  models/clip/util.py  17
1 file changed, 12 insertions, 5 deletions
diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+def unify_input_ids(
+    tokenizer: CLIPTokenizer,
+    input_ids: list[list[int]],
+    max_length: Optional[int] = None,
+):
     if max_length is None:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
             pad_to_multiple_of=tokenizer.model_max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding="max_length",
             max_length=max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
 
+
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
     input_ids: torch.LongTensor,
     position_ids: Optional[torch.LongTensor] = None,
-    attention_mask=None
+    attention_mask=None,
 ):
     model_max_length = text_encoder.config.max_position_embeddings
     prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
     if attention_mask is not None:
         attention_mask = attention_mask.view((-1, model_max_length))
 
-    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_encoder(
+        input_ids, position_ids=position_ids, attention_mask=attention_mask
+    )[0]
     text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
     return text_embeddings
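
For reference, a minimal usage sketch of the two helpers touched by this diff. The checkpoint name and prompt strings are illustrative assumptions, not part of the repository; only unify_input_ids and get_extended_embeddings come from models/clip/util.py.

    # Usage sketch (not part of this commit). Checkpoint and prompts are placeholders.
    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    from models.clip.util import unify_input_ids, get_extended_embeddings

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    prompts = ["a photo of a cat", "a much longer prompt that may exceed 77 tokens"]
    input_ids = tokenizer(prompts).input_ids

    # Pad every prompt to a multiple of tokenizer.model_max_length (77 for CLIP),
    # so that long prompts can be encoded in fixed-size chunks.
    batch = unify_input_ids(tokenizer, input_ids)

    with torch.no_grad():
        # Shape: (num_prompts, padded_length, hidden_size); the per-chunk encoder
        # outputs are concatenated back along the sequence dimension per prompt.
        embeddings = get_extended_embeddings(
            text_encoder, batch.input_ids, attention_mask=batch.attention_mask
        )
    print(embeddings.shape)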
