diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 21 |
1 files changed, 13 insertions, 8 deletions
@@ -8,11 +8,10 @@ from pathlib import Path | |||
8 | import torch | 8 | import torch |
9 | import json | 9 | import json |
10 | from PIL import Image | 10 | from PIL import Image |
11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler |
12 | from transformers import CLIPTextModel, CLIPTokenizer | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
13 | from slugify import slugify | 13 | from slugify import slugify |
14 | 14 | ||
15 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | ||
16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
17 | 16 | ||
18 | 17 | ||
@@ -21,7 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
21 | 20 | ||
22 | default_args = { | 21 | default_args = { |
23 | "model": None, | 22 | "model": None, |
24 | "scheduler": "euler_a", | 23 | "scheduler": "dpmpp", |
25 | "precision": "fp32", | 24 | "precision": "fp32", |
26 | "ti_embeddings_dir": "embeddings_ti", | 25 | "ti_embeddings_dir": "embeddings_ti", |
27 | "output_dir": "output/inference", | 26 | "output_dir": "output/inference", |
@@ -65,7 +64,7 @@ def create_args_parser(): | |||
65 | parser.add_argument( | 64 | parser.add_argument( |
66 | "--scheduler", | 65 | "--scheduler", |
67 | type=str, | 66 | type=str, |
68 | choices=["plms", "ddim", "klms", "euler_a"], | 67 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], |
69 | ) | 68 | ) |
70 | parser.add_argument( | 69 | parser.add_argument( |
71 | "--precision", | 70 | "--precision", |
@@ -222,6 +221,10 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | |||
222 | scheduler = DDIMScheduler( | 221 | scheduler = DDIMScheduler( |
223 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 222 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False |
224 | ) | 223 | ) |
224 | elif scheduler == "dpmpp": | ||
225 | scheduler = DPMSolverMultistepScheduler( | ||
226 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
227 | ) | ||
225 | else: | 228 | else: |
226 | scheduler = EulerAncestralDiscreteScheduler( | 229 | scheduler = EulerAncestralDiscreteScheduler( |
227 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 230 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
@@ -282,7 +285,8 @@ def generate(output_dir, pipeline, args): | |||
282 | ).images | 285 | ).images |
283 | 286 | ||
284 | for j, image in enumerate(images): | 287 | for j, image in enumerate(images): |
285 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 288 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) |
289 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) | ||
286 | 290 | ||
287 | if torch.cuda.is_available(): | 291 | if torch.cuda.is_available(): |
288 | torch.cuda.empty_cache() | 292 | torch.cuda.empty_cache() |
@@ -312,15 +316,16 @@ class CmdParse(cmd.Cmd): | |||
312 | 316 | ||
313 | try: | 317 | try: |
314 | args = run_parser(self.parser, default_cmds, elements) | 318 | args = run_parser(self.parser, default_cmds, elements) |
319 | |||
320 | if len(args.prompt) == 0: | ||
321 | print('Try again with a prompt!') | ||
322 | return | ||
315 | except SystemExit: | 323 | except SystemExit: |
316 | self.parser.print_help() | 324 | self.parser.print_help() |
317 | except Exception as e: | 325 | except Exception as e: |
318 | print(e) | 326 | print(e) |
319 | return | 327 | return |
320 | 328 | ||
321 | if len(args.prompt) == 0: | ||
322 | print('Try again with a prompt!') | ||
323 | |||
324 | try: | 329 | try: |
325 | generate(self.output_dir, self.pipeline, args) | 330 | generate(self.output_dir, self.pipeline, args) |
326 | except KeyboardInterrupt: | 331 | except KeyboardInterrupt: |