diff options
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 23 |
1 files changed, 12 insertions, 11 deletions
| @@ -238,16 +238,15 @@ def create_pipeline(model, dtype): | |||
| 238 | return pipeline | 238 | return pipeline |
| 239 | 239 | ||
| 240 | 240 | ||
| 241 | def shuffle_prompts(prompts: list[str]) -> list[str]: | ||
| 242 | return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] | ||
| 243 | |||
| 244 | |||
| 241 | @torch.inference_mode() | 245 | @torch.inference_mode() |
| 242 | def generate(output_dir: Path, pipeline, args): | 246 | def generate(output_dir: Path, pipeline, args): |
| 243 | if isinstance(args.prompt, str): | 247 | if isinstance(args.prompt, str): |
| 244 | args.prompt = [args.prompt] | 248 | args.prompt = [args.prompt] |
| 245 | 249 | ||
| 246 | if args.shuffle: | ||
| 247 | args.prompt *= args.batch_size | ||
| 248 | args.batch_size = 1 | ||
| 249 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | ||
| 250 | |||
| 251 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] | 250 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] |
| 252 | 251 | ||
| 253 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 252 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| @@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args): | |||
| 263 | dir = output_dir.joinpath(slugify(prompt)[:100]) | 262 | dir = output_dir.joinpath(slugify(prompt)[:100]) |
| 264 | dir.mkdir(parents=True, exist_ok=True) | 263 | dir.mkdir(parents=True, exist_ok=True) |
| 265 | image_dir.append(dir) | 264 | image_dir.append(dir) |
| 266 | |||
| 267 | with open(dir.joinpath('prompt.txt'), 'w') as f: | ||
| 268 | f.write(prompt) | ||
| 269 | else: | 265 | else: |
| 270 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 266 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
| 271 | output_dir.mkdir(parents=True, exist_ok=True) | 267 | output_dir.mkdir(parents=True, exist_ok=True) |
| @@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args): | |||
| 306 | ) | 302 | ) |
| 307 | 303 | ||
| 308 | seed = args.seed + i | 304 | seed = args.seed + i |
| 305 | prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt | ||
| 309 | generator = torch.Generator(device="cuda").manual_seed(seed) | 306 | generator = torch.Generator(device="cuda").manual_seed(seed) |
| 310 | images = pipeline( | 307 | images = pipeline( |
| 311 | prompt=args.prompt, | 308 | prompt=prompt, |
| 312 | negative_prompt=args.negative_prompt, | 309 | negative_prompt=args.negative_prompt, |
| 313 | height=args.height, | 310 | height=args.height, |
| 314 | width=args.width, | 311 | width=args.width, |
| @@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args): | |||
| 321 | ).images | 318 | ).images |
| 322 | 319 | ||
| 323 | for j, image in enumerate(images): | 320 | for j, image in enumerate(images): |
| 321 | basename = f"{seed}_{j // len(args.prompt)}" | ||
| 324 | dir = image_dir[j % len(args.prompt)] | 322 | dir = image_dir[j % len(args.prompt)] |
| 325 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) | 323 | |
| 326 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) | 324 | image.save(dir.joinpath(f"{basename}.png")) |
| 325 | image.save(dir.joinpath(f"{basename}.jpg"), quality=85) | ||
| 326 | with open(dir.joinpath(f"{basename}.txt"), 'w') as f: | ||
| 327 | f.write(prompt[j % len(args.prompt)]) | ||
| 327 | 328 | ||
| 328 | if torch.cuda.is_available(): | 329 | if torch.cuda.is_available(): |
| 329 | torch.cuda.empty_cache() | 330 | torch.cuda.empty_cache() |
