diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:46 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:46 +0100 |
| commit | c779b309a270dba2ac3bebad5a751b3963951ee2 (patch) | |
| tree | eff73cf66687ac2a3aaafea0f67cf67be075e872 | |
| parent | Add contextmanager to EMAModel to apply weights temporarily (diff) | |
| download | textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.tar.gz textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.tar.bz2 textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.zip | |
Add prompt template argument to inference
| -rw-r--r-- | infer.py | 7 |
1 files changed, 7 insertions, 0 deletions
| @@ -48,6 +48,7 @@ default_args = { | |||
| 48 | default_cmds = { | 48 | default_cmds = { |
| 49 | "project": "", | 49 | "project": "", |
| 50 | "scheduler": "dpmsm", | 50 | "scheduler": "dpmsm", |
| 51 | "template": "{}", | ||
| 51 | "prompt": None, | 52 | "prompt": None, |
| 52 | "negative_prompt": None, | 53 | "negative_prompt": None, |
| 53 | "shuffle": False, | 54 | "shuffle": False, |
| @@ -118,6 +119,10 @@ def create_cmd_parser(): | |||
| 118 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], | 119 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], |
| 119 | ) | 120 | ) |
| 120 | parser.add_argument( | 121 | parser.add_argument( |
| 122 | "--template", | ||
| 123 | type=str, | ||
| 124 | ) | ||
| 125 | parser.add_argument( | ||
| 121 | "--prompt", | 126 | "--prompt", |
| 122 | type=str, | 127 | type=str, |
| 123 | nargs="+", | 128 | nargs="+", |
| @@ -243,6 +248,8 @@ def generate(output_dir: Path, pipeline, args): | |||
| 243 | args.batch_size = 1 | 248 | args.batch_size = 1 |
| 244 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | 249 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] |
| 245 | 250 | ||
| 251 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] | ||
| 252 | |||
| 246 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 253 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 247 | image_dir = [] | 254 | image_dir = [] |
| 248 | 255 | ||
