From 3ee13893f9a4973ac75f45fe9318c35760dd4b1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 13:57:46 +0100 Subject: Added progressive aspect ratio bucketing --- infer.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index d3d5f1b..2b07b21 100644 --- a/infer.py +++ b/infer.py @@ -238,16 +238,15 @@ def create_pipeline(model, dtype): return pipeline +def shuffle_prompts(prompts: list[str]) -> list[str]: + return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] + + @torch.inference_mode() def generate(output_dir: Path, pipeline, args): if isinstance(args.prompt, str): args.prompt = [args.prompt] - if args.shuffle: - args.prompt *= args.batch_size - args.batch_size = 1 - args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] - args.prompt = [args.template.format(prompt) for prompt in args.prompt] now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") @@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args): dir = output_dir.joinpath(slugify(prompt)[:100]) dir.mkdir(parents=True, exist_ok=True) image_dir.append(dir) - - with open(dir.joinpath('prompt.txt'), 'w') as f: - f.write(prompt) else: output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") output_dir.mkdir(parents=True, exist_ok=True) @@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args): ) seed = args.seed + i + prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt generator = torch.Generator(device="cuda").manual_seed(seed) images = pipeline( - prompt=args.prompt, + prompt=prompt, negative_prompt=args.negative_prompt, height=args.height, width=args.width, @@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args): ).images for j, image in enumerate(images): + basename = f"{seed}_{j // len(args.prompt)}" dir = image_dir[j % len(args.prompt)] - image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) - image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) + + image.save(dir.joinpath(f"{basename}.png")) + image.save(dir.joinpath(f"{basename}.jpg"), quality=85) + with open(dir.joinpath(f"{basename}.txt"), 'w') as f: + f.write(prompt[j % len(args.prompt)]) if torch.cuda.is_available(): torch.cuda.empty_cache() -- cgit v1.2.3-54-g00ecf