summaryrefslogtreecommitdiffstats
path: root/models/clip/prompt.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/prompt.py')
-rw-r--r--models/clip/prompt.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
1from typing import List, Union 1from typing import Union
2 2
3import torch 3import torch
4 4
@@ -10,13 +10,13 @@ class PromptProcessor():
10 self.tokenizer = tokenizer 10 self.tokenizer = tokenizer
11 self.text_encoder = text_encoder 11 self.text_encoder = text_encoder
12 12
13 def get_input_ids(self, prompt: Union[str, List[str]]): 13 def get_input_ids(self, prompt: Union[str, list[str]]):
14 return self.tokenizer( 14 return self.tokenizer(
15 prompt, 15 prompt,
16 padding="do_not_pad", 16 padding="do_not_pad",
17 ).input_ids 17 ).input_ids
18 18
19 def unify_input_ids(self, input_ids: List[int]): 19 def unify_input_ids(self, input_ids: list[int]):
20 return self.tokenizer.pad( 20 return self.tokenizer.pad(
21 {"input_ids": input_ids}, 21 {"input_ids": input_ids},
22 padding=True, 22 padding=True,