From 3396ca881ed3f3521617cd9024eea56975191d32 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 13:26:32 +0100 Subject: Update --- infer.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 507d0cf..9c27db4 100644 --- a/infer.py +++ b/infer.py @@ -25,6 +25,7 @@ from diffusers import ( ) from transformers import CLIPTextModel +from data.keywords import prompt_to_keywords, keywords_to_prompt from models.clip.embeddings import patch_managed_embeddings from models.clip.tokenizer import MultiCLIPTokenizer from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion @@ -49,6 +50,7 @@ default_cmds = { "scheduler": "dpmsm", "prompt": None, "negative_prompt": None, + "shuffle": True, "image": None, "image_noise": .7, "width": 768, @@ -125,6 +127,10 @@ def create_cmd_parser(): type=str, nargs="*", ) + parser.add_argument( + "--shuffle", + type=bool, + ) parser.add_argument( "--image", type=str, @@ -197,7 +203,7 @@ def load_embeddings(pipeline, embeddings_dir): pipeline.text_encoder.text_model.embeddings, Path(embeddings_dir) ) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") def create_pipeline(model, dtype): @@ -228,20 +234,35 @@ def create_pipeline(model, dtype): @torch.inference_mode() -def generate(output_dir, pipeline, args): +def generate(output_dir: Path, pipeline, args): if isinstance(args.prompt, str): args.prompt = [args.prompt] + if args.shuffle: + args.prompt *= args.batch_size + args.batch_size = 1 + args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - use_subdirs = len(args.prompt) != 1 - if use_subdirs: + image_dir = [] + + if len(args.prompt) != 1: if len(args.project) != 0: output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") else: output_dir = output_dir.joinpath(now) + + for prompt in args.prompt: + dir = output_dir.joinpath(slugify(prompt)[:100]) + dir.mkdir(parents=True, exist_ok=True) + image_dir.append(dir) + + with open(dir.joinpath('prompt.txt'), 'w') as f: + f.write(prompt) else: output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") - output_dir.mkdir(parents=True, exist_ok=True) + output_dir.mkdir(parents=True, exist_ok=True) + image_dir.append(output_dir) args.seed = args.seed or torch.random.seed() @@ -293,12 +314,9 @@ def generate(output_dir, pipeline, args): ).images for j, image in enumerate(images): - image_dir = output_dir - if use_subdirs: - image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100]) - image_dir.mkdir(parents=True, exist_ok=True) - image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) - image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) + dir = image_dir[j % len(args.prompt)] + image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) + image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) if torch.cuda.is_available(): torch.cuda.empty_cache() -- cgit v1.2.3-54-g00ecf