summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--infer.py56
1 files changed, 37 insertions, 19 deletions
diff --git a/infer.py b/infer.py
index 80bd208..51cf3a7 100644
--- a/infer.py
+++ b/infer.py
@@ -5,6 +5,7 @@ import sys
5import shlex 5import shlex
6import cmd 6import cmd
7from pathlib import Path 7from pathlib import Path
8from typing import Optional
8import torch 9import torch
9import json 10import json
10import traceback 11import traceback
@@ -49,7 +50,8 @@ default_args = {
49 50
50default_cmds = { 51default_cmds = {
51 "project": "", 52 "project": "",
52 "scheduler": "dpmsm", 53 "scheduler": "unipc",
54 "subscheduler": None,
53 "template": "{}", 55 "template": "{}",
54 "prompt": None, 56 "prompt": None,
55 "negative_prompt": None, 57 "negative_prompt": None,
@@ -127,6 +129,12 @@ def create_cmd_parser():
127 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], 129 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"],
128 ) 130 )
129 parser.add_argument( 131 parser.add_argument(
132 "--subscheduler",
133 type=str,
134 default=None,
135 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
136 )
137 parser.add_argument(
130 "--template", 138 "--template",
131 type=str, 139 type=str,
132 ) 140 )
@@ -227,6 +235,33 @@ def load_embeddings(pipeline, embeddings_dir):
227 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 235 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
228 236
229 237
238def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
239 if scheduler == "plms":
240 return PNDMScheduler.from_config(config)
241 elif scheduler == "klms":
242 return LMSDiscreteScheduler.from_config(config)
243 elif scheduler == "ddim":
244 return DDIMScheduler.from_config(config)
245 elif scheduler == "dpmsm":
246 return DPMSolverMultistepScheduler.from_config(config)
247 elif scheduler == "dpmss":
248 return DPMSolverSinglestepScheduler.from_config(config)
249 elif scheduler == "euler_a":
250 return EulerAncestralDiscreteScheduler.from_config(config)
251 elif scheduler == "kdpm2":
252 return KDPM2DiscreteScheduler.from_config(config)
253 elif scheduler == "kdpm2_a":
254 return KDPM2AncestralDiscreteScheduler.from_config(config)
255 elif scheduler == "unipc":
256 if subscheduler is None:
257 return UniPCMultistepScheduler.from_config(config)
258 else:
259 return UniPCMultistepScheduler.from_config(
260 config,
261 solver_p=create_scheduler(config, subscheduler),
262 )
263
264
230def create_pipeline(model, dtype): 265def create_pipeline(model, dtype):
231 print("Loading Stable Diffusion pipeline...") 266 print("Loading Stable Diffusion pipeline...")
232 267
@@ -295,24 +330,7 @@ def generate(output_dir: Path, pipeline, args):
295 else: 330 else:
296 init_image = None 331 init_image = None
297 332
298 if args.scheduler == "plms": 333 pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler)
299 pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config)
300 elif args.scheduler == "klms":
301 pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
302 elif args.scheduler == "ddim":
303 pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
304 elif args.scheduler == "dpmsm":
305 pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
306 elif args.scheduler == "dpmss":
307 pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
308 elif args.scheduler == "euler_a":
309 pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
310 elif args.scheduler == "kdpm2":
311 pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
312 elif args.scheduler == "kdpm2_a":
313 pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
314 elif args.scheduler == "unipc":
315 pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
316 334
317 for i in range(args.batch_num): 335 for i in range(args.batch_num):
318 pipeline.set_progress_bar_config( 336 pipeline.set_progress_bar_config(