import argparse
import cmd
import datetime
import json
import logging
import shlex
import sys
from pathlib import Path

import torch
from torch import autocast
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
from slugify import slugify

from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
from schedulers.scheduling_euler_a import EulerAScheduler


def create_args_parser():
    """Arguments for launching the interactive inference shell."""
    parser = argparse.ArgumentParser(
        description="Interactive Stable Diffusion inference script."
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["plms", "ddim", "klms", "euler_a"],
        default="euler_a",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/inference",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
    )
    return parser


def create_cmd_parser():
    """Arguments accepted at the interactive dream> prompt."""
    parser = argparse.ArgumentParser(
        description="Per-prompt generation arguments."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--width",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--height",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--batch_num",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=70,
    )
    parser.add_argument(
        "--guidance_scale",
        type=int,
        default=7,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=torch.random.seed(),
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
    )
    return parser


def run_parser(parser, input=None):
    """Parse known args; if --config points at a JSON file, use its "args" object as defaults."""
    args = parser.parse_known_args(input)[0]

    if args.config is not None:
        with open(args.config, 'rt') as f:
            # Re-parse the same input with the config values as the starting
            # namespace so explicit flags still take precedence.
            args = parser.parse_known_args(
                input, namespace=argparse.Namespace(**json.load(f)["args"]))[0]

    return args


def save_args(basepath, args, extra=None):
    """Persist the generation arguments next to the images."""
    info = {"args": vars(args)}
    info["args"].update(extra or {})
    with open(f"{basepath}/args.json", "w") as f:
        json.dump(info, f, indent=4)


def create_pipeline(model, scheduler, dtype=torch.bfloat16):
    print("Loading Stable Diffusion pipeline...")

    # Load the individual pipeline components from the model directory.
    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
    vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)

    # Build the requested noise scheduler; Euler ancestral is the default.
    if scheduler == "plms":
        scheduler = PNDMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True
        )
    elif scheduler == "klms":
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear"
        )
    elif scheduler == "ddim":
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False
        )
    else:
        scheduler = EulerAScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False
        )

    pipeline = CLIPGuidedStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        feature_extractor=feature_extractor
    )
    pipeline.enable_attention_slicing()
    pipeline.to("cuda")

    print("Pipeline loaded.")

    return pipeline


def generate(output_dir, pipeline, args):
    # Each run gets its own timestamped directory named after the prompt.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
    output_dir.mkdir(parents=True, exist_ok=True)

    save_args(output_dir, args)

    with autocast("cuda"):
        for i in range(args.batch_num):
            # Offset the seed per batch so every batch is reproducible but distinct.
            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
            images = pipeline(
                prompt=[args.prompt] * args.batch_size,
                height=args.height,
                width=args.width,
                negative_prompt=args.negative_prompt,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
            ).images

            for j, image in enumerate(images):
                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))


class CmdParse(cmd.Cmd):
    """Interactive shell: each non-command line is parsed as generation arguments."""

    prompt = 'dream> '
    commands = []

    def __init__(self, output_dir, pipeline, parser):
        super().__init__()

        self.output_dir = output_dir
        self.pipeline = pipeline
        self.parser = parser

    def default(self, line):
        # Escape single quotes so prompts containing apostrophes survive shlex.
        line = line.replace("'", "\\'")

        try:
            elements = shlex.split(line)
        except ValueError as e:
            print(str(e))
            return

        if not elements:
            return
        if elements[0] == 'q':
            return True

        try:
            args = run_parser(self.parser, elements)
        except SystemExit:
            self.parser.print_help()
            return

        if args.prompt is None or len(args.prompt) == 0:
            print('Try again with a prompt!')
            return

        try:
            generate(self.output_dir, self.pipeline, args)
        except KeyboardInterrupt:
            print('Generation cancelled.')

    def do_exit(self, line):
        return True


def main():
    logging.basicConfig(stream=sys.stdout, level=logging.WARN)

    args_parser = create_args_parser()
    args = run_parser(args_parser)

    output_dir = Path(args.output_dir)
    pipeline = create_pipeline(args.model, args.scheduler)

    cmd_parser = create_cmd_parser()
    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
    cmd_prompt.cmdloop()


if __name__ == "__main__":
    main()
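# Example session, sketching how the flags defined above fit together.
# The script filename and model path below are illustrative assumptions,
# not names defined by this repository:
#
#   $ python infer.py --model ./models/stable-diffusion --scheduler euler_a
#   dream> --prompt "a watercolor painting of a lighthouse" --steps 50 --batch_num 2
#   dream> q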