diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 27 |
1 files changed, 22 insertions, 5 deletions
@@ -8,7 +8,18 @@ from pathlib import Path | |||
8 | import torch | 8 | import torch |
9 | import json | 9 | import json |
10 | from PIL import Image | 10 | from PIL import Image |
11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler | 11 | from diffusers import ( |
12 | AutoencoderKL, | ||
13 | UNet2DConditionModel, | ||
14 | PNDMScheduler, | ||
15 | DPMSolverMultistepScheduler, | ||
16 | DPMSolverSinglestepScheduler, | ||
17 | DDIMScheduler, | ||
18 | LMSDiscreteScheduler, | ||
19 | EulerAncestralDiscreteScheduler, | ||
20 | KDPM2DiscreteScheduler, | ||
21 | KDPM2AncestralDiscreteScheduler | ||
22 | ) | ||
12 | from transformers import CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPTextModel, CLIPTokenizer |
13 | from slugify import slugify | 24 | from slugify import slugify |
14 | 25 | ||
@@ -29,7 +40,7 @@ default_args = { | |||
29 | 40 | ||
30 | 41 | ||
31 | default_cmds = { | 42 | default_cmds = { |
32 | "scheduler": "dpmpp", | 43 | "scheduler": "dpmsm", |
33 | "prompt": None, | 44 | "prompt": None, |
34 | "negative_prompt": None, | 45 | "negative_prompt": None, |
35 | "image": None, | 46 | "image": None, |
@@ -38,7 +49,7 @@ default_cmds = { | |||
38 | "height": 512, | 49 | "height": 512, |
39 | "batch_size": 1, | 50 | "batch_size": 1, |
40 | "batch_num": 1, | 51 | "batch_num": 1, |
41 | "steps": 50, | 52 | "steps": 30, |
42 | "guidance_scale": 7.0, | 53 | "guidance_scale": 7.0, |
43 | "seed": None, | 54 | "seed": None, |
44 | "config": None, | 55 | "config": None, |
@@ -90,7 +101,7 @@ def create_cmd_parser(): | |||
90 | parser.add_argument( | 101 | parser.add_argument( |
91 | "--scheduler", | 102 | "--scheduler", |
92 | type=str, | 103 | type=str, |
93 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], | 104 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], |
94 | ) | 105 | ) |
95 | parser.add_argument( | 106 | parser.add_argument( |
96 | "--prompt", | 107 | "--prompt", |
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args): | |||
252 | pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) | 263 | pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) |
253 | elif args.scheduler == "ddim": | 264 | elif args.scheduler == "ddim": |
254 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) | 265 | pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) |
255 | elif args.scheduler == "dpmpp": | 266 | elif args.scheduler == "dpmsm": |
256 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) | 267 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
268 | elif args.scheduler == "dpmss": | ||
269 | pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config) | ||
257 | elif args.scheduler == "euler_a": | 270 | elif args.scheduler == "euler_a": |
258 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | 271 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) |
272 | elif args.scheduler == "kdpm2": | ||
273 | pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config) | ||
274 | elif args.scheduler == "kdpm2_a": | ||
275 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | ||
259 | 276 | ||
260 | with torch.autocast("cuda"), torch.inference_mode(): | 277 | with torch.autocast("cuda"), torch.inference_mode(): |
261 | for i in range(args.batch_num): | 278 | for i in range(args.batch_num): |