From 64c79cc3e7fad49131f90fbb0648b6d5587563e5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 10 Dec 2022 08:43:34 +0100
Subject: Various updated; shuffle prompt content during training

---
 infer.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

(limited to 'infer.py')

diff --git a/infer.py b/infer.py
index 30e11cf..e3fa9e5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
-- 
cgit v1.2.3-54-g00ecf
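
For reference (not part of the commit), below is a minimal sketch of the scheduler-swap pattern the patch extends in generate(): in diffusers, any scheduler class can be instantiated from the current scheduler's config via from_config(), which is exactly what the new "dpmss", "kdpm2" and "kdpm2_a" branches do. The checkpoint id, prompt and output path are placeholders, not values taken from infer.py.

    # Sketch only: demonstrates the Scheduler.from_config() swap used by this patch.
    import torch
    from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler

    # Placeholder checkpoint; the real infer.py loads its own model and config.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Roughly equivalent to passing `--scheduler dpmss` to the updated infer.py:
    # build the new scheduler from the existing scheduler's config.
    pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)

    # 30 steps / guidance 7.0 mirror the new defaults in default_cmds.
    image = pipe(
        "a placeholder prompt",
        num_inference_steps=30,
        guidance_scale=7.0,
    ).images[0]
    image.save("out.png")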