diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 15 |
1 files changed, 5 insertions, 10 deletions
@@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
19 | torch.backends.cuda.matmul.allow_tf32 = True | 19 | torch.backends.cuda.matmul.allow_tf32 = True |
20 | 20 | ||
21 | 21 | ||
22 | line_sep = " <OR> " | ||
23 | |||
24 | |||
25 | default_args = { | 22 | default_args = { |
26 | "model": None, | 23 | "model": None, |
27 | "scheduler": "euler_a", | 24 | "scheduler": "euler_a", |
@@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | |||
254 | 251 | ||
255 | 252 | ||
256 | def generate(output_dir, pipeline, args): | 253 | def generate(output_dir, pipeline, args): |
254 | if isinstance(args.prompt, str): | ||
255 | args.prompt = [args.prompt] | ||
256 | |||
257 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 257 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
258 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 258 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
259 | output_dir.mkdir(parents=True, exist_ok=True) | 259 | output_dir.mkdir(parents=True, exist_ok=True) |
260 | 260 | ||
261 | seed = args.seed or torch.random.seed() | 261 | seed = args.seed or torch.random.seed() |
@@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args): | |||
276 | dynamic_ncols=True | 276 | dynamic_ncols=True |
277 | ) | 277 | ) |
278 | 278 | ||
279 | if isinstance(args.prompt, str): | ||
280 | args.prompt = [args.prompt] | ||
281 | |||
282 | prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size | ||
283 | |||
284 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | 279 | generator = torch.Generator(device="cuda").manual_seed(seed + i) |
285 | images = pipeline( | 280 | images = pipeline( |
286 | prompt=prompt, | 281 | prompt=args.prompt * (args.batch_size // len(args.prompt)), |
287 | height=args.height, | 282 | height=args.height, |
288 | width=args.width, | 283 | width=args.width, |
289 | negative_prompt=args.negative_prompt, | 284 | negative_prompt=args.negative_prompt, |