summaryrefslogtreecommitdiffstats
path: root/models/clip/prompt.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-31 12:58:54 +0100
committerVolpeon <git@volpeon.ink>2022-12-31 12:58:54 +0100
commit6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch)
tree52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /models/clip/prompt.py
parentMisc improvements (diff)
downloadtextual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz
textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2
textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip
Added multi-vector embeddings
Diffstat (limited to 'models/clip/prompt.py')
-rw-r--r--models/clip/prompt.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
1from typing import List, Union 1from typing import Union
2 2
3import torch 3import torch
4 4
@@ -10,13 +10,13 @@ class PromptProcessor():
10 self.tokenizer = tokenizer 10 self.tokenizer = tokenizer
11 self.text_encoder = text_encoder 11 self.text_encoder = text_encoder
12 12
13 def get_input_ids(self, prompt: Union[str, List[str]]): 13 def get_input_ids(self, prompt: Union[str, list[str]]):
14 return self.tokenizer( 14 return self.tokenizer(
15 prompt, 15 prompt,
16 padding="do_not_pad", 16 padding="do_not_pad",
17 ).input_ids 17 ).input_ids
18 18
19 def unify_input_ids(self, input_ids: List[int]): 19 def unify_input_ids(self, input_ids: list[int]):
20 return self.tokenizer.pad( 20 return self.tokenizer.pad(
21 {"input_ids": input_ids}, 21 {"input_ids": input_ids},
22 padding=True, 22 padding=True,