"""Interactive Stable Diffusion inference.

Loads a model once (configured from command-line flags and/or a JSON
``--config`` file), then reads generation commands from a ``dream>`` prompt
and writes the resulting images under ``--output_dir``.
"""

import argparse
import cmd
import datetime
import json
import logging
import shlex
import sys
from pathlib import Path

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from slugify import slugify

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

# Permit TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# Fallback values for the startup options (see create_args_parser).
default_args = {
    "model": None,
    "scheduler": "dpmpp",
    "precision": "fp32",
    "ti_embeddings_dir": "embeddings_ti",
    "output_dir": "output/inference",
    "config": None,
}

# Fallback values for each ``dream>`` command (see create_cmd_parser).
default_cmds = {
    "prompt": None,
    "negative_prompt": None,
    "image": None,
    "image_noise": .7,
    "width": 512,
    "height": 512,
    "batch_size": 1,
    "batch_num": 1,
    "steps": 50,
    "guidance_scale": 7.0,
    "seed": None,
    "config": None,
}


def merge_dicts(d1, *args):
    """Return a copy of ``d1`` updated, in order, with the non-None entries of each dict in ``args``.

    ``d1`` itself is not modified.
    """
    d1 = d1.copy()
    for d in args:
        d1.update({k: v for (k, v) in d.items() if v is not None})
    return d1


def create_args_parser():
    """Build the parser for the startup options (model, scheduler, precision, paths).

    No option declares a default, so unset options parse as ``None`` and
    ``run_parser`` can fill them from a config file or ``default_args``.
    """
    parser = argparse.ArgumentParser(
        description="Run Stable Diffusion inference from an interactive prompt."
    )
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument("--ti_embeddings_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--config", type=str)
    return parser


def create_cmd_parser():
    """Build the parser for a single ``dream>`` generation command.

    ``--prompt`` takes one or more values (one prompt each) and
    ``--negative_prompt`` zero or more. Unset options parse as ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Generate images from a prompt."
    )
    parser.add_argument("--prompt", type=str, nargs="+")
    parser.add_argument("--negative_prompt", type=str, nargs="*")
    parser.add_argument("--image", type=str)
    parser.add_argument("--image_noise", type=float)
    parser.add_argument("--width", type=int)
    parser.add_argument("--height", type=int)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--batch_num", type=int)
    parser.add_argument("--steps", type=int)
    parser.add_argument("--guidance_scale", type=float)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--config", type=str)
    return parser


def run_parser(parser, defaults, input=None):
    """Parse ``input`` with ``parser`` and layer the result over a config file and ``defaults``.

    Precedence, highest first: explicitly given arguments, values from the
    ``"args"`` object of the JSON file named by ``--config``, then ``defaults``.
    ``None`` values never override. Unknown arguments are ignored.

    Args:
        parser: An ``argparse.ArgumentParser`` that defines ``--config``.
        defaults: Mapping of fallback values; it is not modified.
        input: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` holding the merged values.
    """
    args = parser.parse_known_args(input)[0]

    conf_args = argparse.Namespace()
    if args.config is not None:
        with open(args.config, 'rt', encoding="utf-8") as f:
            conf = json.load(f)["args"]
        # Parse the same ``input`` (not sys.argv) on top of the config values.
        conf_args = parser.parse_known_args(
            input, namespace=argparse.Namespace(**conf))[0]

    return argparse.Namespace(**merge_dicts(defaults, vars(conf_args), vars(args)))


def save_args(basepath, args, extra=None):
    """Write ``{"args": ...}`` to ``<basepath>/args.json``.

    The saved mapping is ``vars(args)`` merged with ``extra``. ``args`` itself
    is left untouched.
    """
    info = {"args": {**vars(args), **(extra or {})}}
    with open(Path(basepath) / "args.json", "w", encoding="utf-8") as f:
        json.dump(info, f, indent=4)


def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
    """Register every file in ``embeddings_dir`` as a Textual Inversion token.

    Each regular file's stem becomes a new token. The file must be loadable
    with ``torch.load`` and hold exactly one entry, whose value is written
    into the text encoder's input-embedding row for that token. The directory
    is created if it does not exist.

    Raises:
        ValueError: if an embedding file does not contain exactly one entry.
    """
    print("Loading Textual Inversion embeddings")

    embeddings_dir = Path(embeddings_dir)
    embeddings_dir.mkdir(parents=True, exist_ok=True)

    # List the directory once, in a stable order, so token registration is reproducible.
    files = sorted(file for file in embeddings_dir.iterdir() if file.is_file())

    tokenizer.add_tokens([file.stem for file in files])
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data

    for file in files:
        placeholder_token = file.stem
        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

        # NOTE: torch.load unpickles the file; only load embeddings from trusted sources.
        data = torch.load(file, map_location="cpu")
        if len(data) != 1:
            raise ValueError(
                f"expected exactly one embedding in {file}, found {len(data)}")

        emb = next(iter(data.values()))
        if emb.dim() == 1:
            emb = emb.unsqueeze(0)

        token_embeds[placeholder_token_id] = emb
        print(f"Loaded {placeholder_token}")


def _create_scheduler(name):
    """Build the noise scheduler selected by ``name``; unknown names fall back to Euler ancestral."""
    common = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"}
    if name == "plms":
        return PNDMScheduler(**common, skip_prk_steps=True)
    if name == "klms":
        return LMSDiscreteScheduler(**common)
    if name == "ddim":
        return DDIMScheduler(**common, clip_sample=False, set_alpha_to_one=False)
    if name == "dpmpp":
        return DPMSolverMultistepScheduler(**common)
    return EulerAncestralDiscreteScheduler(**common)


def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
    """Load the model components from ``model`` and assemble a CUDA ``VlpnStableDiffusion`` pipeline.

    Args:
        model: Path or hub id with ``tokenizer``, ``text_encoder``, ``vae`` and ``unet`` subfolders.
        scheduler: One of ``plms``, ``klms``, ``ddim``, ``dpmpp``; anything else selects Euler ancestral.
        ti_embeddings_dir: Directory of Textual Inversion embeddings to register.
        dtype: torch dtype passed to every ``from_pretrained`` call.
    """
    print("Loading Stable Diffusion pipeline...")

    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)

    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=_create_scheduler(scheduler),
    )
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.to("cuda")

    print("Pipeline loaded.")

    return pipeline


def generate(output_dir, pipeline, args):
    """Run ``pipeline`` for one command and save every image it returns.

    A sub-directory ``<timestamp>_<slug of first prompt>`` is created under
    ``output_dir`` (a ``Path``). The effective arguments are saved there as
    ``args.json``. Each image is written as PNG and JPEG named
    ``{seed}_{index}``, where batch ``i`` uses seed ``args.seed + i``.

    ``args.prompt`` (wrapped in a list if it is a plain string) and
    ``args.seed`` (filled with a fresh random seed when ``None``) are updated
    in place, so the saved ``args.json`` records the values actually used.
    """
    if isinstance(args.prompt, str):
        args.prompt = [args.prompt]

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Compare with None so that an explicit `--seed 0` is honoured.
    if args.seed is None:
        args.seed = torch.random.seed()

    save_args(output_dir, args)

    if args.image:
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(args.image) as image_file:
            init_image = image_file.convert("RGB")
    else:
        init_image = None

    try:
        with torch.autocast("cuda"), torch.inference_mode():
            for i in range(args.batch_num):
                batch_seed = args.seed + i
                pipeline.set_progress_bar_config(
                    desc=f"Batch {i + 1} of {args.batch_num}",
                    dynamic_ncols=True
                )

                generator = torch.Generator(device="cuda").manual_seed(batch_seed)
                images = pipeline(
                    prompt=args.prompt,
                    negative_prompt=args.negative_prompt,
                    height=args.height,
                    width=args.width,
                    num_images_per_prompt=args.batch_size,
                    num_inference_steps=args.steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    latents_or_image=init_image,
                    strength=args.image_noise,
                ).images

                for j, image in enumerate(images):
                    image.save(output_dir.joinpath(f"{batch_seed}_{j}.png"))
                    image.save(output_dir.joinpath(f"{batch_seed}_{j}.jpg"), quality=85)
    finally:
        # Release cached GPU memory even when generation is interrupted.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def _split_command_line(line):
    """Split a ``dream>`` line into shell-like arguments.

    Only double quotes group words, so apostrophes stay literal everywhere,
    including inside double-quoted prompts such as ``"a dog's toy"``.
    Backslash escapes work as in ``shlex.split``. Raises ``ValueError`` on an
    unterminated quote.
    """
    lexer = shlex.shlex(line, posix=True)
    lexer.whitespace_split = True
    lexer.commenters = ""
    lexer.quotes = '"'
    return list(lexer)


class CmdParse(cmd.Cmd):
    """``dream>`` prompt: every line that is not a built-in command is a generation request."""

    prompt = 'dream> '
    # Not used by the methods below.
    commands = []

    def __init__(self, output_dir, pipeline, parser):
        super().__init__()

        self.output_dir = output_dir
        self.pipeline = pipeline
        self.parser = parser

    def default(self, line):
        """Parse ``line`` with the command parser and run ``generate``.

        A first word of ``q`` returns True, which ends ``cmdloop``. Parse
        errors are reported and the prompt keeps running.
        """
        try:
            elements = _split_command_line(line)
        except ValueError as e:
            print(str(e))
            return

        if not elements:
            return

        if elements[0] == 'q':
            return True

        try:
            args = run_parser(self.parser, default_cmds, elements)
        except SystemExit as e:
            # argparse exits on --help (code 0) or on invalid input; keep the REPL alive.
            if e.code:
                self.parser.print_help()
            return
        except Exception as e:
            print(e)
            return

        if not args.prompt:
            print('Try again with a prompt!')
            return

        try:
            generate(self.output_dir, self.pipeline, args)
        except KeyboardInterrupt:
            print('Generation cancelled.')

    def do_exit(self, line):
        """Leave the prompt."""
        return True

    def do_EOF(self, line):
        """Leave the prompt on end of input (Ctrl-D or a closed stdin)."""
        print()
        return True


def main():
    """Parse startup options, load the pipeline and run the ``dream>`` prompt."""
    logging.basicConfig(stream=sys.stdout, level=logging.WARNING)

    args_parser = create_args_parser()
    args = run_parser(args_parser, default_args)

    if args.model is None:
        args_parser.error("--model is required (pass it directly or via --config)")

    output_dir = Path(args.output_dir)
    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
    cmd_parser = create_cmd_parser()
    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
    cmd_prompt.cmdloop()


if __name__ == "__main__":
    main()