diff options
Diffstat (limited to 'models/clip/prompt.py')
-rw-r--r-- | models/clip/prompt.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
@@ -1,4 +1,4 @@ | |||
1 | from typing import List, Union | 1 | from typing import Union |
2 | 2 | ||
3 | import torch | 3 | import torch |
4 | 4 | ||
@@ -10,13 +10,13 @@ class PromptProcessor(): | |||
10 | self.tokenizer = tokenizer | 10 | self.tokenizer = tokenizer |
11 | self.text_encoder = text_encoder | 11 | self.text_encoder = text_encoder |
12 | 12 | ||
13 | def get_input_ids(self, prompt: Union[str, List[str]]): | 13 | def get_input_ids(self, prompt: Union[str, list[str]]): |
14 | return self.tokenizer( | 14 | return self.tokenizer( |
15 | prompt, | 15 | prompt, |
16 | padding="do_not_pad", | 16 | padding="do_not_pad", |
17 | ).input_ids | 17 | ).input_ids |
18 | 18 | ||
19 | def unify_input_ids(self, input_ids: List[int]): | 19 | def unify_input_ids(self, input_ids: list[int]): |
20 | return self.tokenizer.pad( | 20 | return self.tokenizer.pad( |
21 | {"input_ids": input_ids}, | 21 | {"input_ids": input_ids}, |
22 | padding=True, | 22 | padding=True, |