diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 12 |
1 file changed, 7 insertions, 5 deletions
@@ -5,12 +5,11 @@ import sys | |||
5 | import shlex | 5 | import shlex |
6 | import cmd | 6 | import cmd |
7 | from pathlib import Path | 7 | from pathlib import Path |
8 | from torch import autocast | ||
9 | import torch | 8 | import torch |
10 | import json | 9 | import json |
11 | from PIL import Image | 10 | from PIL import Image |
12 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
13 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
14 | from slugify import slugify | 13 | from slugify import slugify |
15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
16 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
@@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
22 | default_args = { | 21 | default_args = { |
23 | "model": None, | 22 | "model": None, |
24 | "scheduler": "euler_a", | 23 | "scheduler": "euler_a", |
25 | "precision": "bf16", | 24 | "precision": "fp16", |
26 | "embeddings_dir": "embeddings", | 25 | "embeddings_dir": "embeddings", |
27 | "output_dir": "output/inference", | 26 | "output_dir": "output/inference", |
28 | "config": None, | 27 | "config": None, |
@@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args): | |||
260 | else: | 259 | else: |
261 | init_image = None | 260 | init_image = None |
262 | 261 | ||
263 | with autocast("cuda"): | 262 | with torch.autocast("cuda"), torch.inference_mode(): |
264 | for i in range(args.batch_num): | 263 | for i in range(args.batch_num): |
265 | pipeline.set_progress_bar_config( | 264 | pipeline.set_progress_bar_config( |
266 | desc=f"Batch {i + 1} of {args.batch_num}", | 265 | desc=f"Batch {i + 1} of {args.batch_num}", |
@@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd): | |||
313 | args = run_parser(self.parser, default_cmds, elements) | 312 | args = run_parser(self.parser, default_cmds, elements) |
314 | except SystemExit: | 313 | except SystemExit: |
315 | self.parser.print_help() | 314 | self.parser.print_help() |
315 | except Exception as e: | ||
316 | print(e) | ||
317 | return | ||
316 | 318 | ||
317 | if len(args.prompt) == 0: | 319 | if len(args.prompt) == 0: |
318 | print('Try again with a prompt!') | 320 | print('Try again with a prompt!') |