summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/util.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/models/clip/util.py b/models/clip/util.py
index 883de6a..f94fbc7 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,14 +5,21 @@ import torch
5from transformers import CLIPTokenizer, CLIPTextModel 5from transformers import CLIPTokenizer, CLIPTextModel
6 6
7 7
8def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): 8def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
9 return tokenizer.pad( 9 if max_length is None:
10 {"input_ids": input_ids}, 10 return tokenizer.pad(
11 padding=True, 11 {"input_ids": input_ids},
12 pad_to_multiple_of=tokenizer.model_max_length, 12 padding=True,
13 return_tensors="pt" 13 pad_to_multiple_of=tokenizer.model_max_length,
14 ) 14 return_tensors="pt"
15 15 )
16 else:
17 return tokenizer.pad(
18 {"input_ids": input_ids},
19 padding="max_length",
20 max_length=max_length,
21 return_tensors="pt"
22 )
16 23
17def get_extended_embeddings( 24def get_extended_embeddings(
18 text_encoder: CLIPTextModel, 25 text_encoder: CLIPTextModel,