diff options
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 21 |
1 file changed, 13 insertions, 8 deletions
| @@ -8,11 +8,10 @@ from pathlib import Path | |||
| 8 | import torch | 8 | import torch |
| 9 | import json | 9 | import json |
| 10 | from PIL import Image | 10 | from PIL import Image |
| 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler |
| 12 | from transformers import CLIPTextModel, CLIPTokenizer | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
| 13 | from slugify import slugify | 13 | from slugify import slugify |
| 14 | 14 | ||
| 15 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | ||
| 16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 17 | 16 | ||
| 18 | 17 | ||
| @@ -21,7 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
| 21 | 20 | ||
| 22 | default_args = { | 21 | default_args = { |
| 23 | "model": None, | 22 | "model": None, |
| 24 | "scheduler": "euler_a", | 23 | "scheduler": "dpmpp", |
| 25 | "precision": "fp32", | 24 | "precision": "fp32", |
| 26 | "ti_embeddings_dir": "embeddings_ti", | 25 | "ti_embeddings_dir": "embeddings_ti", |
| 27 | "output_dir": "output/inference", | 26 | "output_dir": "output/inference", |
| @@ -65,7 +64,7 @@ def create_args_parser(): | |||
| 65 | parser.add_argument( | 64 | parser.add_argument( |
| 66 | "--scheduler", | 65 | "--scheduler", |
| 67 | type=str, | 66 | type=str, |
| 68 | choices=["plms", "ddim", "klms", "euler_a"], | 67 | choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], |
| 69 | ) | 68 | ) |
| 70 | parser.add_argument( | 69 | parser.add_argument( |
| 71 | "--precision", | 70 | "--precision", |
| @@ -222,6 +221,10 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | |||
| 222 | scheduler = DDIMScheduler( | 221 | scheduler = DDIMScheduler( |
| 223 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 222 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False |
| 224 | ) | 223 | ) |
| 224 | elif scheduler == "dpmpp": | ||
| 225 | scheduler = DPMSolverMultistepScheduler( | ||
| 226 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 227 | ) | ||
| 225 | else: | 228 | else: |
| 226 | scheduler = EulerAncestralDiscreteScheduler( | 229 | scheduler = EulerAncestralDiscreteScheduler( |
| 227 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 230 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| @@ -282,7 +285,8 @@ def generate(output_dir, pipeline, args): | |||
| 282 | ).images | 285 | ).images |
| 283 | 286 | ||
| 284 | for j, image in enumerate(images): | 287 | for j, image in enumerate(images): |
| 285 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 288 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) |
| 289 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) | ||
| 286 | 290 | ||
| 287 | if torch.cuda.is_available(): | 291 | if torch.cuda.is_available(): |
| 288 | torch.cuda.empty_cache() | 292 | torch.cuda.empty_cache() |
| @@ -312,15 +316,16 @@ class CmdParse(cmd.Cmd): | |||
| 312 | 316 | ||
| 313 | try: | 317 | try: |
| 314 | args = run_parser(self.parser, default_cmds, elements) | 318 | args = run_parser(self.parser, default_cmds, elements) |
| 319 | |||
| 320 | if len(args.prompt) == 0: | ||
| 321 | print('Try again with a prompt!') | ||
| 322 | return | ||
| 315 | except SystemExit: | 323 | except SystemExit: |
| 316 | self.parser.print_help() | 324 | self.parser.print_help() |
| 317 | except Exception as e: | 325 | except Exception as e: |
| 318 | print(e) | 326 | print(e) |
| 319 | return | 327 | return |
| 320 | 328 | ||
| 321 | if len(args.prompt) == 0: | ||
| 322 | print('Try again with a prompt!') | ||
| 323 | |||
| 324 | try: | 329 | try: |
| 325 | generate(self.output_dir, self.pipeline, args) | 330 | generate(self.output_dir, self.pipeline, args) |
| 326 | except KeyboardInterrupt: | 331 | except KeyboardInterrupt: |
