From 1aace3e44dae0489130039714f67d980628c92ec Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 12:59:08 +0200
Subject: Avoid model recompilation due to varying prompt lengths

---
 models/clip/util.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

(limited to 'models/clip')

diff --git a/models/clip/util.py b/models/clip/util.py
index 883de6a..f94fbc7 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,14 +5,21 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
-    return tokenizer.pad(
-        {"input_ids": input_ids},
-        padding=True,
-        pad_to_multiple_of=tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+    if max_length is None:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=tokenizer.model_max_length,
+            return_tensors="pt"
+        )
+    else:
+        return tokenizer.pad(
+            {"input_ids": input_ids},
+            padding="max_length",
+            max_length=max_length,
+            return_tensors="pt"
+        )
 
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
-- 
cgit v1.2.3-70-g09d2
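The sketch below (not part of the patch) illustrates how a caller might use the new `max_length` argument so that every batch is padded to the same fixed length, which keeps the input shape constant for a compiled text encoder instead of varying with the longest prompt in each batch. The import path `models.clip.util`, the checkpoint name, the example prompts, and the chosen fixed length are assumptions for illustration only.

```python
from transformers import CLIPTokenizer

from models.clip.util import unify_input_ids  # path assumed from this repo's layout

# Standard CLIP tokenizer checkpoint, assumed here for illustration.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompts = [
    "a photo of a cat",
    "a highly detailed oil painting of a fox in a forest at night",
]
input_ids = [tokenizer(p).input_ids for p in prompts]

# Old behavior: pad dynamically, rounded up to a multiple of model_max_length,
# so the resulting shape depends on the longest prompt in the batch.
dynamic = unify_input_ids(tokenizer, input_ids)

# New behavior: pass a fixed max_length so every call yields the same shape,
# avoiding recompilation when prompt lengths differ between batches.
fixed = unify_input_ids(tokenizer, input_ids, max_length=2 * tokenizer.model_max_length)

print(dynamic.input_ids.shape)  # varies with the batch's longest prompt
print(fixed.input_ids.shape)    # constant, e.g. (2, 154) for the chosen max_length
```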