From 5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Oct 2022 11:40:14 +0200 Subject: Made inference script interactive --- infer.py | 154 ++++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 109 insertions(+), 45 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index de3d792..40720ea 100644 --- a/infer.py +++ b/infer.py @@ -1,18 +1,21 @@ import argparse import datetime import logging +import sys +import shlex +import cmd from pathlib import Path from torch import autocast import torch import json from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler -from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor +from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor from slugify import slugify from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion from schedulers.scheduling_euler_a import EulerAScheduler -def parse_args(): +def create_args_parser(): parser = argparse.ArgumentParser( description="Simple example of a training script." ) @@ -21,6 +24,30 @@ def parse_args(): type=str, default=None, ) + parser.add_argument( + "--scheduler", + type=str, + choices=["plms", "ddim", "klms", "euler_a"], + default="euler_a", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output/inference", + ) + parser.add_argument( + "--config", + type=str, + default=None, + ) + + return parser + + +def create_cmd_parser(): + parser = argparse.ArgumentParser( + description="Simple example of a training script." + ) parser.add_argument( "--prompt", type=str, @@ -49,50 +76,39 @@ def parse_args(): parser.add_argument( "--batch_num", type=int, - default=50, + default=1, ) parser.add_argument( "--steps", type=int, - default=120, - ) - parser.add_argument( - "--scheduler", - type=str, - choices=["plms", "ddim", "klms", "euler_a"], - default="euler_a", + default=70, ) parser.add_argument( "--guidance_scale", type=int, - default=7.5, - ) - parser.add_argument( - "--clip_guidance_scale", - type=int, - default=100, + default=7, ) parser.add_argument( "--seed", type=int, default=torch.random.seed(), ) - parser.add_argument( - "--output_dir", - type=str, - default="output/inference", - ) parser.add_argument( "--config", type=str, default=None, ) - args = parser.parse_args() + return parser + + +def run_parser(parser, input=None): + args = parser.parse_known_args(input)[0] + if args.config is not None: with open(args.config, 'rt') as f: - args = parser.parse_args( - namespace=argparse.Namespace(**json.load(f)["args"])) + args = parser.parse_known_args( + namespace=argparse.Namespace(**json.load(f)["args"]))[0] return args @@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}): json.dump(info, f, indent=4) -def gen(args, output_dir): - tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) - text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) - clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) - vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) - unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) - feature_extractor = CLIPFeatureExtractor.from_pretrained( - "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) +def create_pipeline(model, scheduler, dtype=torch.bfloat16): + print("Loading Stable Diffusion pipeline...") - if args.scheduler == "plms": + tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) + text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) + vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) + unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) + feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype) + + if scheduler == "plms": scheduler = PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True ) - elif args.scheduler == "klms": + elif scheduler == "klms": scheduler = LMSDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) - elif args.scheduler == "ddim": + elif scheduler == "ddim": scheduler = DDIMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False ) @@ -135,13 +151,24 @@ def gen(args, output_dir): vae=vae, unet=unet, tokenizer=tokenizer, - clip_model=clip_model, scheduler=scheduler, feature_extractor=feature_extractor ) pipeline.enable_attention_slicing() pipeline.to("cuda") + print("Pipeline loaded.") + + return pipeline + + +def generate(output_dir, pipeline, args): + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") + output_dir.mkdir(parents=True, exist_ok=True) + + save_args(output_dir, args) + with autocast("cuda"): for i in range(args.batch_num): generator = torch.Generator(device="cuda").manual_seed(args.seed + i) @@ -152,7 +179,6 @@ def gen(args, output_dir): negative_prompt=args.negative_prompt, num_inference_steps=args.steps, guidance_scale=args.guidance_scale, - clip_guidance_scale=args.clip_guidance_scale, generator=generator, ).images @@ -160,18 +186,56 @@ def gen(args, output_dir): image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) -def main(): - args = parse_args() +class CmdParse(cmd.Cmd): + prompt = 'dream> ' + commands = [] - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") - output_dir.mkdir(parents=True, exist_ok=True) + def __init__(self, output_dir, pipeline, parser): + super().__init__() - save_args(output_dir, args) + self.output_dir = output_dir + self.pipeline = pipeline + self.parser = parser + + def default(self, line): + line = line.replace("'", "\\'") + + try: + elements = shlex.split(line) + except ValueError as e: + print(str(e)) + + if elements[0] == 'q': + return True + + try: + args = run_parser(self.parser, elements) + except SystemExit: + self.parser.print_help() + + if len(args.prompt) == 0: + print('Try again with a prompt!') + + try: + generate(self.output_dir, self.pipeline, args) + except KeyboardInterrupt: + print('Generation cancelled.') + + def do_exit(self, line): + return True + + +def main(): + logging.basicConfig(stream=sys.stdout, level=logging.WARN) - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + args_parser = create_args_parser() + args = run_parser(args_parser) + output_dir = Path(args.output_dir) - gen(args, output_dir) + pipeline = create_pipeline(args.model, args.scheduler) + cmd_parser = create_cmd_parser() + cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) + cmd_prompt.cmdloop() if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf