diff options
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 15 |
1 file changed, 5 insertions, 10 deletions
| @@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
| 19 | torch.backends.cuda.matmul.allow_tf32 = True | 19 | torch.backends.cuda.matmul.allow_tf32 = True |
| 20 | 20 | ||
| 21 | 21 | ||
| 22 | line_sep = " <OR> " | ||
| 23 | |||
| 24 | |||
| 25 | default_args = { | 22 | default_args = { |
| 26 | "model": None, | 23 | "model": None, |
| 27 | "scheduler": "euler_a", | 24 | "scheduler": "euler_a", |
| @@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | |||
| 254 | 251 | ||
| 255 | 252 | ||
| 256 | def generate(output_dir, pipeline, args): | 253 | def generate(output_dir, pipeline, args): |
| 254 | if isinstance(args.prompt, str): | ||
| 255 | args.prompt = [args.prompt] | ||
| 256 | |||
| 257 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 257 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 258 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 258 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
| 259 | output_dir.mkdir(parents=True, exist_ok=True) | 259 | output_dir.mkdir(parents=True, exist_ok=True) |
| 260 | 260 | ||
| 261 | seed = args.seed or torch.random.seed() | 261 | seed = args.seed or torch.random.seed() |
| @@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args): | |||
| 276 | dynamic_ncols=True | 276 | dynamic_ncols=True |
| 277 | ) | 277 | ) |
| 278 | 278 | ||
| 279 | if isinstance(args.prompt, str): | ||
| 280 | args.prompt = [args.prompt] | ||
| 281 | |||
| 282 | prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size | ||
| 283 | |||
| 284 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | 279 | generator = torch.Generator(device="cuda").manual_seed(seed + i) |
| 285 | images = pipeline( | 280 | images = pipeline( |
| 286 | prompt=prompt, | 281 | prompt=args.prompt * (args.batch_size // len(args.prompt)), |
| 287 | height=args.height, | 282 | height=args.height, |
| 288 | width=args.width, | 283 | width=args.width, |
| 289 | negative_prompt=args.negative_prompt, | 284 | negative_prompt=args.negative_prompt, |
