diff options
author | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
commit | 728dfcf57c30f40236b3a00d7380c4e0057cacb3 (patch) | |
tree | 9aee7759b7f31752a87a1c9af4d9c4ea20f9a862 /infer.py | |
parent | Upstream updates; better handling of textual embedding (diff) | |
download | textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.gz textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.bz2 textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.zip |
Implemented extended prompt limit
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 12 |
1 files changed, 11 insertions, 1 deletions
@@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
19 | torch.backends.cuda.matmul.allow_tf32 = True | 19 | torch.backends.cuda.matmul.allow_tf32 = True |
20 | 20 | ||
21 | 21 | ||
22 | line_sep = " <OR> " | ||
23 | |||
24 | |||
22 | default_args = { | 25 | default_args = { |
23 | "model": None, | 26 | "model": None, |
24 | "scheduler": "euler_a", | 27 | "scheduler": "euler_a", |
@@ -95,10 +98,12 @@ def create_cmd_parser(): | |||
95 | parser.add_argument( | 98 | parser.add_argument( |
96 | "--prompt", | 99 | "--prompt", |
97 | type=str, | 100 | type=str, |
101 | nargs="+", | ||
98 | ) | 102 | ) |
99 | parser.add_argument( | 103 | parser.add_argument( |
100 | "--negative_prompt", | 104 | "--negative_prompt", |
101 | type=str, | 105 | type=str, |
106 | nargs="*", | ||
102 | ) | 107 | ) |
103 | parser.add_argument( | 108 | parser.add_argument( |
104 | "--image", | 109 | "--image", |
@@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args): | |||
271 | dynamic_ncols=True | 276 | dynamic_ncols=True |
272 | ) | 277 | ) |
273 | 278 | ||
279 | if isinstance(args.prompt, str): | ||
280 | args.prompt = [args.prompt] | ||
281 | |||
282 | prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size | ||
283 | |||
274 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | 284 | generator = torch.Generator(device="cuda").manual_seed(seed + i) |
275 | images = pipeline( | 285 | images = pipeline( |
276 | prompt=[args.prompt] * args.batch_size, | 286 | prompt=prompt, |
277 | height=args.height, | 287 | height=args.height, |
278 | width=args.width, | 288 | width=args.width, |
279 | negative_prompt=args.negative_prompt, | 289 | negative_prompt=args.negative_prompt, |