summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-08 21:56:54 +0200
committerVolpeon <git@volpeon.ink>2022-10-08 21:56:54 +0200
commit6aadb34af4fe5ca2dfc92fae8eee87610a5848ad (patch)
treef490b4794366e78f7b079eb04de1c7c00e17d34a /infer.py
parentFix small details (diff)
downloadtextual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.gz
textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.bz2
textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.zip
Update
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/infer.py b/infer.py
index 6197aa3..a542534 100644
--- a/infer.py
+++ b/infer.py
@@ -5,12 +5,11 @@ import sys
5import shlex 5import shlex
6import cmd 6import cmd
7from pathlib import Path 7from pathlib import Path
8from torch import autocast
9import torch 8import torch
10import json 9import json
11from PIL import Image 10from PIL import Image
12from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler 11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
13from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor 12from transformers import CLIPTextModel, CLIPTokenizer
14from slugify import slugify 13from slugify import slugify
15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 14from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16from schedulers.scheduling_euler_a import EulerAScheduler 15from schedulers.scheduling_euler_a import EulerAScheduler
@@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
22default_args = { 21default_args = {
23 "model": None, 22 "model": None,
24 "scheduler": "euler_a", 23 "scheduler": "euler_a",
25 "precision": "bf16", 24 "precision": "fp16",
26 "embeddings_dir": "embeddings", 25 "embeddings_dir": "embeddings",
27 "output_dir": "output/inference", 26 "output_dir": "output/inference",
28 "config": None, 27 "config": None,
@@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args):
260 else: 259 else:
261 init_image = None 260 init_image = None
262 261
263 with autocast("cuda"): 262 with torch.autocast("cuda"), torch.inference_mode():
264 for i in range(args.batch_num): 263 for i in range(args.batch_num):
265 pipeline.set_progress_bar_config( 264 pipeline.set_progress_bar_config(
266 desc=f"Batch {i + 1} of {args.batch_num}", 265 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd):
313 args = run_parser(self.parser, default_cmds, elements) 312 args = run_parser(self.parser, default_cmds, elements)
314 except SystemExit: 313 except SystemExit:
315 self.parser.print_help() 314 self.parser.print_help()
315 except Exception as e:
316 print(e)
317 return
316 318
317 if len(args.prompt) == 0: 319 if len(args.prompt) == 0:
318 print('Try again with a prompt!') 320 print('Try again with a prompt!')