diff options
author | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:46 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:46 +0100 |
commit | c779b309a270dba2ac3bebad5a751b3963951ee2 (patch) | |
tree | eff73cf66687ac2a3aaafea0f67cf67be075e872 | |
parent | Add contextmanager to EMAModel to apply weights temporarily (diff) | |
download | textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.tar.gz textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.tar.bz2 textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.zip |
Add prompt template argument to inference
-rw-r--r-- | infer.py | 7 |
1 files changed, 7 insertions, 0 deletions
@@ -48,6 +48,7 @@ default_args = { | |||
48 | default_cmds = { | 48 | default_cmds = { |
49 | "project": "", | 49 | "project": "", |
50 | "scheduler": "dpmsm", | 50 | "scheduler": "dpmsm", |
51 | "template": "{}", | ||
51 | "prompt": None, | 52 | "prompt": None, |
52 | "negative_prompt": None, | 53 | "negative_prompt": None, |
53 | "shuffle": False, | 54 | "shuffle": False, |
@@ -118,6 +119,10 @@ def create_cmd_parser(): | |||
118 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], | 119 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], |
119 | ) | 120 | ) |
120 | parser.add_argument( | 121 | parser.add_argument( |
122 | "--template", | ||
123 | type=str, | ||
124 | ) | ||
125 | parser.add_argument( | ||
121 | "--prompt", | 126 | "--prompt", |
122 | type=str, | 127 | type=str, |
123 | nargs="+", | 128 | nargs="+", |
@@ -243,6 +248,8 @@ def generate(output_dir: Path, pipeline, args): | |||
243 | args.batch_size = 1 | 248 | args.batch_size = 1 |
244 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | 249 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] |
245 | 250 | ||
251 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] | ||
252 | |||
246 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 253 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
247 | image_dir = [] | 254 | image_dir = [] |
248 | 255 | ||