summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-06 09:07:46 +0100
committerVolpeon <git@volpeon.ink>2023-01-06 09:07:46 +0100
commitc779b309a270dba2ac3bebad5a751b3963951ee2 (patch)
treeeff73cf66687ac2a3aaafea0f67cf67be075e872
parentAdd contextmanager to EMAModel to apply weights temporarily (diff)
downloadtextual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.tar.gz
textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.tar.bz2
textual-inversion-diff-c779b309a270dba2ac3bebad5a751b3963951ee2.zip
Add prompt template argument to inference
-rw-r--r--infer.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/infer.py b/infer.py
index e31cd88..d3d5f1b 100644
--- a/infer.py
+++ b/infer.py
@@ -48,6 +48,7 @@ default_args = {
48default_cmds = { 48default_cmds = {
49 "project": "", 49 "project": "",
50 "scheduler": "dpmsm", 50 "scheduler": "dpmsm",
51 "template": "{}",
51 "prompt": None, 52 "prompt": None,
52 "negative_prompt": None, 53 "negative_prompt": None,
53 "shuffle": False, 54 "shuffle": False,
@@ -118,6 +119,10 @@ def create_cmd_parser():
118 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], 119 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
119 ) 120 )
120 parser.add_argument( 121 parser.add_argument(
122 "--template",
123 type=str,
124 )
125 parser.add_argument(
121 "--prompt", 126 "--prompt",
122 type=str, 127 type=str,
123 nargs="+", 128 nargs="+",
@@ -243,6 +248,8 @@ def generate(output_dir: Path, pipeline, args):
243 args.batch_size = 1 248 args.batch_size = 1
244 args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] 249 args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
245 250
251 args.prompt = [args.template.format(prompt) for prompt in args.prompt]
252
246 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 253 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
247 image_dir = [] 254 image_dir = []
248 255