diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-08 21:56:54 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-08 21:56:54 +0200 |
| commit | 6aadb34af4fe5ca2dfc92fae8eee87610a5848ad (patch) | |
| tree | f490b4794366e78f7b079eb04de1c7c00e17d34a /infer.py | |
| parent | Fix small details (diff) | |
| download | textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.gz textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.bz2 textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.zip | |
Update
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 12 |
1 files changed, 7 insertions, 5 deletions
| @@ -5,12 +5,11 @@ import sys | |||
| 5 | import shlex | 5 | import shlex |
| 6 | import cmd | 6 | import cmd |
| 7 | from pathlib import Path | 7 | from pathlib import Path |
| 8 | from torch import autocast | ||
| 9 | import torch | 8 | import torch |
| 10 | import json | 9 | import json |
| 11 | from PIL import Image | 10 | from PIL import Image |
| 12 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
| 13 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
| 14 | from slugify import slugify | 13 | from slugify import slugify |
| 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 16 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
| @@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
| 22 | default_args = { | 21 | default_args = { |
| 23 | "model": None, | 22 | "model": None, |
| 24 | "scheduler": "euler_a", | 23 | "scheduler": "euler_a", |
| 25 | "precision": "bf16", | 24 | "precision": "fp16", |
| 26 | "embeddings_dir": "embeddings", | 25 | "embeddings_dir": "embeddings", |
| 27 | "output_dir": "output/inference", | 26 | "output_dir": "output/inference", |
| 28 | "config": None, | 27 | "config": None, |
| @@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args): | |||
| 260 | else: | 259 | else: |
| 261 | init_image = None | 260 | init_image = None |
| 262 | 261 | ||
| 263 | with autocast("cuda"): | 262 | with torch.autocast("cuda"), torch.inference_mode(): |
| 264 | for i in range(args.batch_num): | 263 | for i in range(args.batch_num): |
| 265 | pipeline.set_progress_bar_config( | 264 | pipeline.set_progress_bar_config( |
| 266 | desc=f"Batch {i + 1} of {args.batch_num}", | 265 | desc=f"Batch {i + 1} of {args.batch_num}", |
| @@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd): | |||
| 313 | args = run_parser(self.parser, default_cmds, elements) | 312 | args = run_parser(self.parser, default_cmds, elements) |
| 314 | except SystemExit: | 313 | except SystemExit: |
| 315 | self.parser.print_help() | 314 | self.parser.print_help() |
| 315 | except Exception as e: | ||
| 316 | print(e) | ||
| 317 | return | ||
| 316 | 318 | ||
| 317 | if len(args.prompt) == 0: | 319 | if len(args.prompt) == 0: |
| 318 | print('Try again with a prompt!') | 320 | print('Try again with a prompt!') |
