From 33884c491acaa22edea467a514d7c75263a30bb1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 17 Feb 2023 17:57:18 +0100 Subject: Inference script: Better scheduler config --- infer.py | 56 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/infer.py b/infer.py index 80bd208..51cf3a7 100644 --- a/infer.py +++ b/infer.py @@ -5,6 +5,7 @@ import sys import shlex import cmd from pathlib import Path +from typing import Optional import torch import json import traceback @@ -49,7 +50,8 @@ default_args = { default_cmds = { "project": "", - "scheduler": "dpmsm", + "scheduler": "unipc", + "subscheduler": None, "template": "{}", "prompt": None, "negative_prompt": None, @@ -126,6 +128,12 @@ def create_cmd_parser(): type=str, choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"], ) + parser.add_argument( + "--subscheduler", + type=str, + default=None, + choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], + ) parser.add_argument( "--template", type=str, @@ -227,6 +235,33 @@ def load_embeddings(pipeline, embeddings_dir): print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") +def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None): + if scheduler == "plms": + return PNDMScheduler.from_config(config) + elif scheduler == "klms": + return LMSDiscreteScheduler.from_config(config) + elif scheduler == "ddim": + return DDIMScheduler.from_config(config) + elif scheduler == "dpmsm": + return DPMSolverMultistepScheduler.from_config(config) + elif scheduler == "dpmss": + return DPMSolverSinglestepScheduler.from_config(config) + elif scheduler == "euler_a": + return EulerAncestralDiscreteScheduler.from_config(config) + elif scheduler == "kdpm2": + return KDPM2DiscreteScheduler.from_config(config) + elif scheduler == "kdpm2_a": + return KDPM2AncestralDiscreteScheduler.from_config(config) + elif scheduler == "unipc": + if subscheduler is None: + return UniPCMultistepScheduler.from_config(config) + else: + return UniPCMultistepScheduler.from_config( + config, + solver_p=create_scheduler(config, subscheduler), + ) + + def create_pipeline(model, dtype): print("Loading Stable Diffusion pipeline...") @@ -295,24 +330,7 @@ def generate(output_dir: Path, pipeline, args): else: init_image = None - if args.scheduler == "plms": - pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "klms": - pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "ddim": - pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "dpmsm": - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "dpmss": - pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "euler_a": - pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "kdpm2": - pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "kdpm2_a": - pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) - elif args.scheduler == "unipc": - pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) for i in range(args.batch_num): pipeline.set_progress_bar_config( -- cgit v1.2.3-70-g09d2