summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-17 22:08:58 +0200
committerVolpeon <git@volpeon.ink>2022-10-17 22:08:58 +0200
commit728dfcf57c30f40236b3a00d7380c4e0057cacb3 (patch)
tree9aee7759b7f31752a87a1c9af4d9c4ea20f9a862 /infer.py
parentUpstream updates; better handling of textual embedding (diff)
downloadtextual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.gz
textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.bz2
textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.zip
Implemented extended prompt limit
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/infer.py b/infer.py
index 1a0baf5..d744768 100644
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
19torch.backends.cuda.matmul.allow_tf32 = True 19torch.backends.cuda.matmul.allow_tf32 = True
20 20
21 21
22line_sep = " <OR> "
23
24
22default_args = { 25default_args = {
23 "model": None, 26 "model": None,
24 "scheduler": "euler_a", 27 "scheduler": "euler_a",
@@ -95,10 +98,12 @@ def create_cmd_parser():
95 parser.add_argument( 98 parser.add_argument(
96 "--prompt", 99 "--prompt",
97 type=str, 100 type=str,
101 nargs="+",
98 ) 102 )
99 parser.add_argument( 103 parser.add_argument(
100 "--negative_prompt", 104 "--negative_prompt",
101 type=str, 105 type=str,
106 nargs="*",
102 ) 107 )
103 parser.add_argument( 108 parser.add_argument(
104 "--image", 109 "--image",
@@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args):
271 dynamic_ncols=True 276 dynamic_ncols=True
272 ) 277 )
273 278
279 if isinstance(args.prompt, str):
280 args.prompt = [args.prompt]
281
282 prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
283
274 generator = torch.Generator(device="cuda").manual_seed(seed + i) 284 generator = torch.Generator(device="cuda").manual_seed(seed + i)
275 images = pipeline( 285 images = pipeline(
276 prompt=[args.prompt] * args.batch_size, 286 prompt=prompt,
277 height=args.height, 287 height=args.height,
278 width=args.width, 288 width=args.width,
279 negative_prompt=args.negative_prompt, 289 negative_prompt=args.negative_prompt,