diff options
author | Volpeon <git@volpeon.ink> | 2022-12-31 12:58:54 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-31 12:58:54 +0100 |
commit | 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch) | |
tree | 52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /models/clip/prompt.py | |
parent | Misc improvements (diff) | |
download | textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2 textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip |
Added multi-vector embeddings
Diffstat (limited to 'models/clip/prompt.py')
-rw-r--r-- | models/clip/prompt.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
@@ -1,4 +1,4 @@ | |||
1 | from typing import List, Union | 1 | from typing import Union |
2 | 2 | ||
3 | import torch | 3 | import torch |
4 | 4 | ||
@@ -10,13 +10,13 @@ class PromptProcessor(): | |||
10 | self.tokenizer = tokenizer | 10 | self.tokenizer = tokenizer |
11 | self.text_encoder = text_encoder | 11 | self.text_encoder = text_encoder |
12 | 12 | ||
13 | def get_input_ids(self, prompt: Union[str, List[str]]): | 13 | def get_input_ids(self, prompt: Union[str, list[str]]): |
14 | return self.tokenizer( | 14 | return self.tokenizer( |
15 | prompt, | 15 | prompt, |
16 | padding="do_not_pad", | 16 | padding="do_not_pad", |
17 | ).input_ids | 17 | ).input_ids |
18 | 18 | ||
19 | def unify_input_ids(self, input_ids: List[int]): | 19 | def unify_input_ids(self, input_ids: list[int]): |
20 | return self.tokenizer.pad( | 20 | return self.tokenizer.pad( |
21 | {"input_ids": input_ids}, | 21 | {"input_ids": input_ids}, |
22 | padding=True, | 22 | padding=True, |