From 728dfcf57c30f40236b3a00d7380c4e0057cacb3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 17 Oct 2022 22:08:58 +0200
Subject: Implemented extended prompt limit

---
 infer.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/infer.py b/infer.py
index 1a0baf5..d744768 100644
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
+line_sep = " "
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -95,10 +98,12 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
+        nargs="+",
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
+        nargs="*",
     )
     parser.add_argument(
         "--image",
@@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
+        if isinstance(args.prompt, str):
+            args.prompt = [args.prompt]
+
+        prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
+
         generator = torch.Generator(device="cuda").manual_seed(seed + i)
         images = pipeline(
-            prompt=[args.prompt] * args.batch_size,
+            prompt=prompt,
             height=args.height,
             width=args.width,
             negative_prompt=args.negative_prompt,
-- 
cgit v1.2.3-54-g00ecf
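
Not part of the patch: a minimal sketch of what the new prompt handling
produces, assuming hypothetical --prompt values and batch size (line_sep
is the module-level separator added above):

    line_sep = " "                        # separator defined in infer.py
    prompts = ["a cat sitting", "a dog"]  # --prompt values via nargs="+"
    batch_size = 2

    # Each prompt string is split on line_sep into a list of chunks, and the
    # resulting list of lists is repeated once per image in the batch:
    prompt = [p.split(line_sep) for p in prompts] * batch_size
    # [['a', 'cat', 'sitting'], ['a', 'dog'],
    #  ['a', 'cat', 'sitting'], ['a', 'dog']]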