summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py15
1 files changed, 5 insertions, 10 deletions
diff --git a/infer.py b/infer.py
index d744768..8e17c4e 100644
--- a/infer.py
+++ b/infer.py
@@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
19torch.backends.cuda.matmul.allow_tf32 = True 19torch.backends.cuda.matmul.allow_tf32 = True
20 20
21 21
22line_sep = " <OR> "
23
24
25default_args = { 22default_args = {
26 "model": None, 23 "model": None,
27 "scheduler": "euler_a", 24 "scheduler": "euler_a",
@@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
254 251
255 252
256def generate(output_dir, pipeline, args): 253def generate(output_dir, pipeline, args):
254 if isinstance(args.prompt, str):
255 args.prompt = [args.prompt]
256
257 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 257 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
258 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") 258 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
259 output_dir.mkdir(parents=True, exist_ok=True) 259 output_dir.mkdir(parents=True, exist_ok=True)
260 260
261 seed = args.seed or torch.random.seed() 261 seed = args.seed or torch.random.seed()
@@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args):
276 dynamic_ncols=True 276 dynamic_ncols=True
277 ) 277 )
278 278
279 if isinstance(args.prompt, str):
280 args.prompt = [args.prompt]
281
282 prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
283
284 generator = torch.Generator(device="cuda").manual_seed(seed + i) 279 generator = torch.Generator(device="cuda").manual_seed(seed + i)
285 images = pipeline( 280 images = pipeline(
286 prompt=prompt, 281 prompt=args.prompt * (args.batch_size // len(args.prompt)),
287 height=args.height, 282 height=args.height,
288 width=args.width, 283 width=args.width,
289 negative_prompt=args.negative_prompt, 284 negative_prompt=args.negative_prompt,