commit 728dfcf57c30f40236b3a00d7380c4e0057cacb3
Tree:      9aee7759b7f31752a87a1c9af4d9c4ea20f9a862 /infer.py
Author:    Volpeon <git@volpeon.ink>  2022-10-17 22:08:58 +0200
Committer: Volpeon <git@volpeon.ink>  2022-10-17 22:08:58 +0200
Parent:    Upstream updates; better handling of textual embedding
Implemented extended prompt limit
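The change lets --prompt take one or more values on the command line (nargs="+") and --negative_prompt zero or more (nargs="*"). A minimal sketch of the changed parser behaviour, assuming the rest of create_cmd_parser() is unchanged (the prompt strings are placeholders):

    import argparse

    # Sketch of the two changed arguments from create_cmd_parser().
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, nargs="+")           # one or more values
    parser.add_argument("--negative_prompt", type=str, nargs="*")  # zero or more values

    args = parser.parse_args(
        ["--prompt", "a photo of a cat <OR> highly detailed", "a watercolor dog"]
    )
    print(args.prompt)
    # ['a photo of a cat <OR> highly detailed', 'a watercolor dog']

Each value may embed the new " <OR> " separator, which generate() splits on before handing the prompts to the pipeline (see the sketch after the diff).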
Diffstat (limited to 'infer.py')
 infer.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 torch.backends.cuda.matmul.allow_tf32 = True
 
 
+line_sep = " <OR> "
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -95,10 +98,12 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
+        nargs="+",
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
+        nargs="*",
     )
     parser.add_argument(
         "--image",
@@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args):
         dynamic_ncols=True
     )
 
+    if isinstance(args.prompt, str):
+        args.prompt = [args.prompt]
+
+    prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
+
     generator = torch.Generator(device="cuda").manual_seed(seed + i)
     images = pipeline(
-        prompt=[args.prompt] * args.batch_size,
+        prompt=prompt,
         height=args.height,
         width=args.width,
         negative_prompt=args.negative_prompt,
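For illustration, here is how the new expansion in generate() behaves; line_sep and the list comprehension are taken verbatim from the diff, while the input values are made up. The isinstance() guard keeps a plain string prompt (e.g. one coming from default_args) working, and the trailing * args.batch_size repeats the whole list once per image in the batch:

    line_sep = " <OR> "

    # Stand-ins for the parsed arguments.
    prompt_args = ["a photo of a cat <OR> highly detailed", "a watercolor dog"]
    batch_size = 2

    if isinstance(prompt_args, str):
        prompt_args = [prompt_args]

    # Split each value on the separator, then repeat the list per batch item.
    prompt = [p.split(line_sep) for p in prompt_args] * batch_size
    print(prompt)
    # [['a photo of a cat', 'highly detailed'], ['a watercolor dog'],
    #  ['a photo of a cat', 'highly detailed'], ['a watercolor dog']]

The nested lists are what the custom VlpnStableDiffusion pipeline now receives as its prompt argument; how it combines the sub-prompts to get past the usual token limit is not part of this diff.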
