diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/util.py | 23 |
1 file changed, 15 insertions, 8 deletions
diff --git a/models/clip/util.py b/models/clip/util.py index 883de6a..f94fbc7 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
@@ -5,14 +5,21 @@ import torch | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | 5 | from transformers import CLIPTokenizer, CLIPTextModel |
6 | 6 | ||
7 | 7 | ||
8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): | 8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None): |
9 | return tokenizer.pad( | 9 | if max_length is None: |
10 | {"input_ids": input_ids}, | 10 | return tokenizer.pad( |
11 | padding=True, | 11 | {"input_ids": input_ids}, |
12 | pad_to_multiple_of=tokenizer.model_max_length, | 12 | padding=True, |
13 | return_tensors="pt" | 13 | pad_to_multiple_of=tokenizer.model_max_length, |
14 | ) | 14 | return_tensors="pt" |
15 | 15 | ) | |
16 | else: | ||
17 | return tokenizer.pad( | ||
18 | {"input_ids": input_ids}, | ||
19 | padding="max_length", | ||
20 | max_length=max_length, | ||
21 | return_tensors="pt" | ||
22 | ) | ||
16 | 23 | ||
17 | def get_extended_embeddings( | 24 | def get_extended_embeddings( |
18 | text_encoder: CLIPTextModel, | 25 | text_encoder: CLIPTextModel, |