"""Interactive command-line front-end for a Stable Diffusion inference pipeline.

Loads a VlpnStableDiffusion pipeline (plus textual-inversion embeddings) once,
then drops into a `dream> ` REPL where each line is parsed as generation
arguments and rendered to timestamped output directories.
"""

import argparse
import cmd
import datetime
import json
import logging
import shlex
import sys
from pathlib import Path

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    PNDMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from transformers import CLIPTextModel, CLIPTokenizer
from slugify import slugify

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

# Speed knobs: TF32 matmuls and cuDNN autotuning are safe for inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


# Fallbacks for the one-time startup arguments (see create_args_parser).
default_args = {
    "model": None,
    "precision": "fp32",
    "ti_embeddings_dir": "embeddings_ti",
    "output_dir": "output/inference",
    "config": None,
}


# Fallbacks for the per-generation REPL arguments (see create_cmd_parser).
default_cmds = {
    "scheduler": "dpmsm",
    "prompt": None,
    "negative_prompt": None,
    "image": None,
    "image_noise": 0.7,
    "width": 512,
    "height": 512,
    "batch_size": 1,
    "batch_num": 1,
    "steps": 30,
    "guidance_scale": 7.0,
    "seed": None,
    "config": None,
}


def merge_dicts(d1, *args):
    """Return a copy of d1 updated with the non-None entries of each extra dict.

    Later dicts win; None values never overwrite an existing entry.
    """
    d1 = d1.copy()

    for d in args:
        d1.update({k: v for (k, v) in d.items() if v is not None})

    return d1


def create_args_parser():
    """Build the parser for the one-time startup (model/paths) arguments."""
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--model",
        type=str,
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument(
        "--ti_embeddings_dir",
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
    )
    parser.add_argument(
        "--config",
        type=str,
    )
    return parser


def create_cmd_parser():
    """Build the parser for the per-generation REPL arguments."""
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        nargs="*",
    )
    parser.add_argument(
        "--image",
        type=str,
    )
    parser.add_argument(
        "--image_noise",
        type=float,
    )
    parser.add_argument(
        "--width",
        type=int,
    )
    parser.add_argument(
        "--height",
        type=int,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
    )
    parser.add_argument(
        "--batch_num",
        type=int,
    )
    parser.add_argument(
        "--steps",
        type=int,
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
    )
    parser.add_argument(
        "--seed",
        type=int,
    )
    parser.add_argument(
        "--config",
        type=str,
    )
    return parser


def run_parser(parser, defaults, input=None):
    """Parse `input` (or sys.argv) and merge defaults < config file < CLI.

    If --config names a JSON file, its "args" object is used as a middle
    layer between the hard-coded defaults and the explicit CLI values.
    None values never overwrite; returns an argparse.Namespace.
    """
    args = parser.parse_known_args(input)[0]
    conf_args = argparse.Namespace()

    if args.config is not None:
        with open(args.config, "rt") as f:
            conf_args = parser.parse_known_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))[0]

    res = defaults.copy()
    # FIX: loop variable used to shadow the builtin `dict`.
    for layer in [vars(conf_args), vars(args)]:
        res.update({k: v for (k, v) in layer.items() if v is not None})

    return argparse.Namespace(**res)


def save_args(basepath, args, extra=None):
    """Dump the effective arguments (plus optional extras) to args.json.

    FIX: `extra` used to be a mutable default argument ({}).
    """
    info = {"args": vars(args)}
    if extra:
        info["args"].update(extra)
    with open(f"{basepath}/args.json", "w") as f:
        json.dump(info, f, indent=4)


def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
    """Load every textual-inversion embedding file found in embeddings_dir.

    Each file's stem becomes a new placeholder token; the (single) tensor
    stored in the file is written into the text encoder's embedding matrix.
    """
    print(f"Loading Textual Inversion embeddings")

    embeddings_dir = Path(embeddings_dir)
    embeddings_dir.mkdir(parents=True, exist_ok=True)

    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
    tokenizer.add_tokens(placeholder_tokens)
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data

    for file in embeddings_dir.iterdir():
        if file.is_file():
            placeholder_token = file.stem
            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

            data = torch.load(file, map_location="cpu")
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)

            token_embeds[placeholder_token_id] = emb

            print(f"Loaded {placeholder_token}")


def create_pipeline(model, ti_embeddings_dir, dtype):
    """Assemble the VlpnStableDiffusion pipeline on CUDA with TI embeddings."""
    print("Loading Stable Diffusion pipeline...")

    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)

    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    )
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_vae_slicing()
    pipeline.to("cuda")

    print("Pipeline loaded.")

    return pipeline


# REPL scheduler key -> diffusers scheduler class.
_SCHEDULER_CLASSES = {
    "plms": PNDMScheduler,
    "klms": LMSDiscreteScheduler,
    "ddim": DDIMScheduler,
    "dpmsm": DPMSolverMultistepScheduler,
    "dpmss": DPMSolverSinglestepScheduler,
    "euler_a": EulerAncestralDiscreteScheduler,
    "kdpm2": KDPM2DiscreteScheduler,
    "kdpm2_a": KDPM2AncestralDiscreteScheduler,
}


def generate(output_dir, pipeline, args):
    """Run one (possibly multi-batch) generation and save PNG+JPEG outputs.

    Creates a timestamped, prompt-slugged subdirectory of output_dir, records
    the effective args there, then renders batch_num batches; batch i uses
    seed + i so runs are reproducible per image.
    """
    if isinstance(args.prompt, str):
        args.prompt = [args.prompt]

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # FIX: `args.seed or ...` discarded an explicit seed of 0.
    if args.seed is None:
        args.seed = torch.random.seed()

    save_args(output_dir, args)

    if args.image:
        init_image = Image.open(args.image)
        if not init_image.mode == "RGB":
            init_image = init_image.convert("RGB")
    else:
        init_image = None

    # Swap in the requested scheduler (dispatch table replaces an if/elif chain).
    scheduler_cls = _SCHEDULER_CLASSES.get(args.scheduler)
    if scheduler_cls is not None:
        pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)

    with torch.autocast("cuda"), torch.inference_mode():
        for i in range(args.batch_num):
            pipeline.set_progress_bar_config(
                desc=f"Batch {i + 1} of {args.batch_num}",
                dynamic_ncols=True
            )

            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
            images = pipeline(
                prompt=args.prompt,
                negative_prompt=args.negative_prompt,
                height=args.height,
                width=args.width,
                num_images_per_prompt=args.batch_size,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
                image=init_image,
                strength=args.image_noise,
            ).images

            for j, image in enumerate(images):
                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class CmdParse(cmd.Cmd):
    """Interactive `dream> ` shell; each input line is a generation request."""

    prompt = 'dream> '
    commands = []

    def __init__(self, output_dir, pipeline, parser):
        super().__init__()

        self.output_dir = output_dir
        self.pipeline = pipeline
        self.parser = parser

    def default(self, line):
        """Parse one REPL line and generate; 'q' quits the loop."""
        # Keep single quotes literal so shlex doesn't treat them as quoting.
        line = line.replace("'", "\\'")

        try:
            elements = shlex.split(line)
        except ValueError as e:
            print(str(e))
            return  # FIX: used to fall through and NameError on `elements`

        if not elements:
            return  # FIX: empty line used to IndexError on elements[0]

        if elements[0] == 'q':
            return True

        try:
            args = run_parser(self.parser, default_cmds, elements)
            # FIX: `len(args.prompt)` raised TypeError when prompt was None.
            if not args.prompt:
                print('Try again with a prompt!')
                return
        except SystemExit:
            self.parser.print_help()
            return  # FIX: used to fall through with `args` undefined
        except Exception as e:
            print(e)
            return

        try:
            generate(self.output_dir, self.pipeline, args)
        except KeyboardInterrupt:
            print('Generation cancelled.')
        except Exception as e:
            print(e)

    def do_exit(self, line):
        """Exit the REPL."""
        return True


def main():
    """Entry point: load the pipeline once, then run the REPL."""
    logging.basicConfig(stream=sys.stdout, level=logging.WARN)

    args_parser = create_args_parser()
    args = run_parser(args_parser, default_args)

    output_dir = Path(args.output_dir)
    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)

    cmd_parser = create_cmd_parser()
    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
    cmd_prompt.cmdloop()


if __name__ == "__main__":
    main()