diff options
Diffstat (limited to 'models/clip/prompt.py')
| -rw-r--r-- | models/clip/prompt.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | from typing import List, Union | 1 | from typing import Union |
| 2 | 2 | ||
| 3 | import torch | 3 | import torch |
| 4 | 4 | ||
| @@ -10,13 +10,13 @@ class PromptProcessor(): | |||
| 10 | self.tokenizer = tokenizer | 10 | self.tokenizer = tokenizer |
| 11 | self.text_encoder = text_encoder | 11 | self.text_encoder = text_encoder |
| 12 | 12 | ||
| 13 | def get_input_ids(self, prompt: Union[str, List[str]]): | 13 | def get_input_ids(self, prompt: Union[str, list[str]]): |
| 14 | return self.tokenizer( | 14 | return self.tokenizer( |
| 15 | prompt, | 15 | prompt, |
| 16 | padding="do_not_pad", | 16 | padding="do_not_pad", |
| 17 | ).input_ids | 17 | ).input_ids |
| 18 | 18 | ||
| 19 | def unify_input_ids(self, input_ids: List[int]): | 19 | def unify_input_ids(self, input_ids: list[int]): |
| 20 | return self.tokenizer.pad( | 20 | return self.tokenizer.pad( |
| 21 | {"input_ids": input_ids}, | 21 | {"input_ids": input_ids}, |
| 22 | padding=True, | 22 | padding=True, |
