diff options
-rw-r--r-- | infer.py | 56 |
1 files changed, 37 insertions, 19 deletions
@@ -5,6 +5,7 @@ import sys | |||
5 | import shlex | 5 | import shlex |
6 | import cmd | 6 | import cmd |
7 | from pathlib import Path | 7 | from pathlib import Path |
8 | from typing import Optional | ||
8 | import torch | 9 | import torch |
9 | import json | 10 | import json |
10 | import traceback | 11 | import traceback |
@@ -49,7 +50,8 @@ default_args = { | |||
49 | 50 | ||
50 | default_cmds = { | 51 | default_cmds = { |
51 | "project": "", | 52 | "project": "", |
52 | "scheduler": "dpmsm", | 53 | "scheduler": "unipc", |
54 | "subscheduler": None, | ||
53 | "template": "{}", | 55 | "template": "{}", |
54 | "prompt": None, | 56 | "prompt": None, |
55 | "negative_prompt": None, | 57 | "negative_prompt": None, |
@@ -127,6 +129,12 @@ def create_cmd_parser(): | |||
127 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], | 129 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], |
128 | ) | 130 | ) |
129 | parser.add_argument( | 131 | parser.add_argument( |
132 | "--subscheduler", | ||
133 | type=str, | ||
134 | default=None, | ||
135 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], | ||
136 | ) | ||
137 | parser.add_argument( | ||
130 | "--template", | 138 | "--template", |
131 | type=str, | 139 | type=str, |
132 | ) | 140 | ) |
@@ -227,6 +235,33 @@ def load_embeddings(pipeline, embeddings_dir): | |||
227 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 235 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
228 | 236 | ||
229 | 237 | ||
238 | def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None): | ||
239 | if scheduler == "plms": | ||
240 | return PNDMScheduler.from_config(config) | ||
241 | elif scheduler == "klms": | ||
242 | return LMSDiscreteScheduler.from_config(config) | ||
243 | elif scheduler == "ddim": | ||
244 | return DDIMScheduler.from_config(config) | ||
245 | elif scheduler == "dpmsm": | ||
246 | return DPMSolverMultistepScheduler.from_config(config) | ||
247 | elif scheduler == "dpmss": | ||
248 | return DPMSolverSinglestepScheduler.from_config(config) | ||
249 | elif scheduler == "euler_a": | ||
250 | return EulerAncestralDiscreteScheduler.from_config(config) | ||
251 | elif scheduler == "kdpm2": | ||
252 | return KDPM2DiscreteScheduler.from_config(config) | ||
253 | elif scheduler == "kdpm2_a": | ||
254 | return KDPM2AncestralDiscreteScheduler.from_config(config) | ||
255 | elif scheduler == "unipc": | ||
256 | if subscheduler is None: | ||
257 | return UniPCMultistepScheduler.from_config(config) | ||
258 | else: | ||
259 | return UniPCMultistepScheduler.from_config( | ||
260 | config, | ||
261 | solver_p=create_scheduler(config, subscheduler), | ||
262 | ) | ||
263 | |||
264 | |||
230 | def create_pipeline(model, dtype): | 265 | def create_pipeline(model, dtype): |
231 | print("Loading Stable Diffusion pipeline...") | 266 | print("Loading Stable Diffusion pipeline...") |
232 | 267 | ||
@@ -295,24 +330,7 @@ def generate(output_dir: Path, pipeline, args): | |||
295 | else: | 330 | else: |
296 | init_image = None | 331 | init_image = None |
297 | 332 | ||
298 | if args.scheduler == "plms": | 333 | pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) |
299 | pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config) | ||
300 | elif args.scheduler == "klms": | ||
301 | pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
302 | elif args.scheduler == "ddim": | ||
303 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) | ||
304 | elif args.scheduler == "dpmsm": | ||
305 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) | ||
306 | elif args.scheduler == "dpmss": | ||
307 | pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config) | ||
308 | elif args.scheduler == "euler_a": | ||
309 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
310 | elif args.scheduler == "kdpm2": | ||
311 | pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) | ||
312 | elif args.scheduler == "kdpm2_a": | ||
313 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
314 | elif args.scheduler == "unipc": | ||
315 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) | ||
316 | 334 | ||
317 | for i in range(args.batch_num): | 335 | for i in range(args.batch_num): |
318 | pipeline.set_progress_bar_config( | 336 | pipeline.set_progress_bar_config( |