summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/infer.py b/infer.py
index 9bc9efe..9b0ec1f 100644
--- a/infer.py
+++ b/infer.py
@@ -8,11 +8,10 @@ from pathlib import Path
8import torch 8import torch
9import json 9import json
10from PIL import Image 10from PIL import Image
11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler 11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
12from transformers import CLIPTextModel, CLIPTokenizer 12from transformers import CLIPTextModel, CLIPTokenizer
13from slugify import slugify 13from slugify import slugify
14 14
15from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
16from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
17 16
18 17
@@ -21,7 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
21 20
22default_args = { 21default_args = {
23 "model": None, 22 "model": None,
24 "scheduler": "euler_a", 23 "scheduler": "dpmpp",
25 "precision": "fp32", 24 "precision": "fp32",
26 "ti_embeddings_dir": "embeddings_ti", 25 "ti_embeddings_dir": "embeddings_ti",
27 "output_dir": "output/inference", 26 "output_dir": "output/inference",
@@ -65,7 +64,7 @@ def create_args_parser():
65 parser.add_argument( 64 parser.add_argument(
66 "--scheduler", 65 "--scheduler",
67 type=str, 66 type=str,
68 choices=["plms", "ddim", "klms", "euler_a"], 67 choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
69 ) 68 )
70 parser.add_argument( 69 parser.add_argument(
71 "--precision", 70 "--precision",
@@ -222,6 +221,10 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
222 scheduler = DDIMScheduler( 221 scheduler = DDIMScheduler(
223 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False 222 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
224 ) 223 )
224 elif scheduler == "dpmpp":
225 scheduler = DPMSolverMultistepScheduler(
226 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
227 )
225 else: 228 else:
226 scheduler = EulerAncestralDiscreteScheduler( 229 scheduler = EulerAncestralDiscreteScheduler(
227 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 230 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -282,7 +285,8 @@ def generate(output_dir, pipeline, args):
282 ).images 285 ).images
283 286
284 for j, image in enumerate(images): 287 for j, image in enumerate(images):
285 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) 288 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
289 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
286 290
287 if torch.cuda.is_available(): 291 if torch.cuda.is_available():
288 torch.cuda.empty_cache() 292 torch.cuda.empty_cache()
@@ -312,15 +316,16 @@ class CmdParse(cmd.Cmd):
312 316
313 try: 317 try:
314 args = run_parser(self.parser, default_cmds, elements) 318 args = run_parser(self.parser, default_cmds, elements)
319
320 if len(args.prompt) == 0:
321 print('Try again with a prompt!')
322 return
315 except SystemExit: 323 except SystemExit:
316 self.parser.print_help() 324 self.parser.print_help()
317 except Exception as e: 325 except Exception as e:
318 print(e) 326 print(e)
319 return 327 return
320 328
321 if len(args.prompt) == 0:
322 print('Try again with a prompt!')
323
324 try: 329 try:
325 generate(self.output_dir, self.pipeline, args) 330 generate(self.output_dir, self.pipeline, args)
326 except KeyboardInterrupt: 331 except KeyboardInterrupt: