import argparse
import datetime
import logging
import sys
import shlex
import cmd
from pathlib import Path
from typing import Optional
import torch
import json
import traceback

from PIL import Image
from slugify import slugify
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    PNDMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    DEISMultistepScheduler,
    UniPCMultistepScheduler,
)
from peft import LoraConfig, LoraModel, set_peft_model_state_dict
from safetensors.torch import load_file
from transformers import CLIPTextModel

from data.keywords import str_to_keywords, keywords_to_str
from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from util.files import load_config, load_embeddings_from_dir
from util.ti import load_embeddings


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


# Fallback values for the program's command-line arguments (see create_args_parser).
default_args = {
    "model": "stabilityai/stable-diffusion-2-1",
    "precision": "fp32",
    "ti_embeddings_dir": "embeddings_ti",
    "lora_embedding": None,
    "output_dir": "output/inference",
    "config": None,
}


# Fallback values for each command typed at the interactive prompt (see create_cmd_parser).
default_cmds = {
    "project": "",
    "scheduler": "unipc",
    "subscheduler": None,
    "template": "{}",
    "prompt": None,
    "negative_prompt": None,
    "shuffle": False,
    "image": None,
    "image_noise": 0.7,
    "width": 768,
    "height": 768,
    "batch_size": 1,
    "batch_num": 1,
    "steps": 30,
    "guidance_scale": 7.0,
    "sag_scale": 0,
    "seed": None,
    "config": None,
}


def merge_dicts(d1, *args):
    """Return a copy of ``d1`` updated with every non-``None`` value from ``args``.

    Later dicts win over earlier ones; ``d1`` itself is not modified.
    """
    d1 = d1.copy()

    for d in args:
        d1.update({k: v for (k, v) in d.items() if v is not None})

    return d1


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns *any* non-empty string (including
    ``"False"``) into ``True``, so booleans are parsed explicitly here.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def create_args_parser():
    """Build the parser for the program's own command-line arguments.

    No option has a default here; unset values are filled from ``default_args``
    by :func:`run_parser`.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--model",
        type=str,
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument(
        "--ti_embeddings_dir",
        type=str,
    )
    parser.add_argument(
        "--lora_embedding",
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
    )
    parser.add_argument(
        "--config",
        type=str,
    )

    return parser


def create_cmd_parser():
    """Build the parser for commands typed at the interactive ``dream>`` prompt.

    Options without an explicit default resolve to ``None`` here and are filled
    from ``default_cmds`` by :func:`run_parser`.
    """
    subscheduler_choices = [
        "plms",
        "ddim",
        "klms",
        "dpmsm",
        "dpmss",
        "euler_a",
        "kdpm2",
        "kdpm2_a",
        "deis",
    ]

    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        # Same order as before: the sub-scheduler choices followed by "unipc".
        choices=[*subscheduler_choices, "unipc"],
    )
    parser.add_argument(
        "--subscheduler",
        type=str,
        default=None,
        choices=subscheduler_choices,
    )
    parser.add_argument(
        "--template",
        type=str,
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        nargs="*",
    )
    parser.add_argument(
        "--shuffle",
        # ``type=bool`` would make "--shuffle False" truthy; parse the word instead.
        # A bare "--shuffle" (no value) means True.
        type=_str_to_bool,
        nargs="?",
        const=True,
    )
    parser.add_argument(
        "--image",
        type=str,
    )
    parser.add_argument(
        "--image_noise",
        type=float,
    )
    parser.add_argument(
        "--width",
        type=int,
    )
    parser.add_argument(
        "--height",
        type=int,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
    )
    parser.add_argument(
        "--batch_num",
        type=int,
    )
    parser.add_argument(
        "--steps",
        type=int,
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
    )
    parser.add_argument(
        "--sag_scale",
        type=float,
    )
    parser.add_argument(
        "--seed",
        type=int,
    )
    parser.add_argument(
        "--config",
        type=str,
    )

    return parser


def run_parser(parser, defaults, input=None):
    """Parse ``input`` and layer the result over an optional config file and ``defaults``.

    Precedence, lowest to highest: ``defaults`` < values from ``--config`` (read with
    ``load_config``) < values given in ``input``. ``None`` values never override.

    Args:
        parser: the ``argparse.ArgumentParser`` to use.
        defaults: dict of fallback values.
        input: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the merged values.
    """
    args = parser.parse_known_args(input)[0]
    conf_args = argparse.Namespace()

    if args.config is not None:
        conf_args = load_config(args.config)
        # Re-parse the *same* input on top of the config values. Omitting ``input``
        # here would silently parse ``sys.argv`` instead (e.g. the program's own
        # arguments when called from the interactive prompt).
        conf_args = parser.parse_known_args(
            input, namespace=argparse.Namespace(**conf_args)
        )[0]

    res = defaults.copy()
    for source in (vars(conf_args), vars(args)):
        res.update({k: v for (k, v) in source.items() if v is not None})

    return argparse.Namespace(**res)


def save_args(basepath, args, extra=None):
    """Write ``args`` (plus optional ``extra`` entries) to ``{basepath}/args.json``.

    The JSON has the shape ``{"args": {...}}``. Neither ``args`` nor ``extra`` is
    modified (``vars(args)`` is the namespace's own ``__dict__``, so it is copied).
    """
    info = {"args": {**vars(args), **(extra or {})}}

    with open(f"{basepath}/args.json", "w") as f:
        json.dump(info, f, indent=4)


def load_embeddings_dir(pipeline, embeddings_dir):
    """Load every embedding in ``embeddings_dir`` into the pipeline's text encoder.

    Uses ``load_embeddings_from_dir`` with the pipeline's tokenizer and text-encoder
    embeddings, calls ``persist()`` on those embeddings, and prints the added
    tokens together with their ids.
    """
    added_tokens, added_ids = load_embeddings_from_dir(
        pipeline.tokenizer,
        pipeline.text_encoder.text_model.embeddings,
        Path(embeddings_dir),
    )
    pipeline.text_encoder.text_model.embeddings.persist()
    print(
        f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
    )


def load_lora(pipeline, path):
    """Apply a LoRA checkpoint directory to ``pipeline``.

    ``path`` must contain ``lora_config.json`` and at least one
    ``*_end.safetensors`` file; the first one in sorted order is used. Keys of
    that file are split by substring:

    * ``"text_encoder_"`` -> text-encoder LoRA weights (prefix removed), applied
      only if the config has ``"text_encoder_peft_config"``;
    * ``"ti_"`` -> textual-inversion token embeddings (prefix removed);
    * everything else -> UNet LoRA weights (config key ``"peft_config"``).

    Does nothing if ``path`` is ``None`` or no matching tensor file exists.
    """
    if path is None:
        return

    path = Path(path)

    with open(path / "lora_config.json", "r") as f:
        lora_config = json.load(f)

    # ``glob`` already yields paths that include ``path``, so they are used as-is
    # (joining them onto ``path`` again would double a relative directory).
    # Sorting makes the choice deterministic when several files match.
    tensor_files = sorted(path.glob("*_end.safetensors"))

    if len(tensor_files) == 0:
        return

    lora_checkpoint_sd = load_file(tensor_files[0])
    unet_lora_ds = {
        k: v
        for k, v in lora_checkpoint_sd.items()
        if "text_encoder_" not in k and "ti_" not in k
    }
    text_encoder_lora_ds = {
        k.replace("text_encoder_", ""): v
        for k, v in lora_checkpoint_sd.items()
        if "text_encoder_" in k
    }
    ti_lora_ds = {
        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
    }

    unet_config = LoraConfig(**lora_config["peft_config"])
    pipeline.unet = LoraModel(unet_config, pipeline.unet)
    set_peft_model_state_dict(pipeline.unet, unet_lora_ds)

    if "text_encoder_peft_config" in lora_config:
        text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
        pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
        set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)

    if ti_lora_ds:
        # Iterating the dict itself would yield only the key strings; split it
        # into parallel token / embedding lists explicitly.
        tokens = list(ti_lora_ds.keys())
        token_embeddings = list(ti_lora_ds.values())

        added_tokens, added_ids = load_embeddings(
            tokenizer=pipeline.tokenizer,
            embeddings=pipeline.text_encoder.text_model.embeddings,
            tokens=tokens,
            token_embeddings=token_embeddings,
        )
        pipeline.text_encoder.text_model.embeddings.persist()
        print(
            f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
        )

    return


def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
    """Instantiate the diffusers scheduler named ``scheduler`` from ``config``.

    For ``"unipc"``, a non-``None`` ``subscheduler`` is built the same way and
    passed to ``UniPCMultistepScheduler`` as ``solver_p``.

    Raises:
        ValueError: if ``scheduler`` is not a known name.
    """
    scheduler_classes = {
        "plms": PNDMScheduler,
        "klms": LMSDiscreteScheduler,
        "ddim": DDIMScheduler,
        "dpmsm": DPMSolverMultistepScheduler,
        "dpmss": DPMSolverSinglestepScheduler,
        "euler_a": EulerAncestralDiscreteScheduler,
        "kdpm2": KDPM2DiscreteScheduler,
        "kdpm2_a": KDPM2AncestralDiscreteScheduler,
        "deis": DEISMultistepScheduler,
    }

    if scheduler == "unipc":
        if subscheduler is None:
            return UniPCMultistepScheduler.from_config(config)
        return UniPCMultistepScheduler.from_config(
            config,
            solver_p=create_scheduler(config, subscheduler),
        )

    if scheduler in scheduler_classes:
        return scheduler_classes[scheduler].from_config(config)

    raise ValueError(f'Unknown scheduler "{scheduler}"')


def create_pipeline(model, dtype):
    """Load all components of ``model`` in ``dtype`` and build a CUDA pipeline.

    Loads the tokenizer, text encoder, VAE, UNet and a DDIM scheduler from the
    model's subfolders and patches the text encoder with ``patch_managed_embeddings``.
    It then wraps everything in ``VlpnStableDiffusion``, enables xformers attention
    and VAE slicing, and moves the pipeline to ``"cuda"``.
    """
    print("Loading Stable Diffusion pipeline...")

    tokenizer = MultiCLIPTokenizer.from_pretrained(
        model, subfolder="tokenizer", torch_dtype=dtype
    )
    text_encoder = CLIPTextModel.from_pretrained(
        model, subfolder="text_encoder", torch_dtype=dtype
    )
    vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(
        model, subfolder="unet", torch_dtype=dtype
    )
    scheduler = DDIMScheduler.from_pretrained(
        model, subfolder="scheduler", torch_dtype=dtype
    )

    patch_managed_embeddings(text_encoder)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    )
    pipeline.enable_xformers_memory_efficient_attention()
    # Optional: pipeline.unet = torch.compile(pipeline.unet)
    pipeline.enable_vae_slicing()
    pipeline.to("cuda")

    print("Pipeline loaded.")

    return pipeline


def shuffle_prompts(prompts: list[str]) -> list[str]:
    """Return ``prompts`` with each prompt's keywords re-joined in shuffled order."""
    return [
        keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
    ]


@torch.inference_mode()
def generate(output_dir: Path, pipeline, args):
    """Render ``args.batch_num`` batches of images for the prompts in ``args``.

    Each prompt is formatted through ``args.template``. A single prompt writes to
    ``output_dir/<timestamp>_<prompt slug>``. Several prompts write to one
    sub-directory per prompt slug under ``output_dir/<timestamp>[_<project slug>]``.
    Every image is saved as ``.png`` and ``.jpg`` next to a ``.txt`` file holding
    its prompt. The resolved args (including the seed actually used) go to
    ``args.json``.

    ``args`` is modified in place: ``prompt`` becomes the templated list and
    ``seed`` is filled in when it was ``None``.
    """
    if isinstance(args.prompt, str):
        args.prompt = [args.prompt]

    args.prompt = [args.template.format(prompt) for prompt in args.prompt]

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    image_dir = []

    if len(args.prompt) != 1:
        if args.project:
            output_dir = output_dir / f"{now}_{slugify(args.project)}"
        else:
            output_dir = output_dir / now

        for prompt in args.prompt:
            prompt_dir = output_dir / slugify(prompt)[:100]
            prompt_dir.mkdir(parents=True, exist_ok=True)
            image_dir.append(prompt_dir)
    else:
        output_dir = output_dir / f"{now}_{slugify(args.prompt[0])[:100]}"
        output_dir.mkdir(parents=True, exist_ok=True)
        image_dir.append(output_dir)

    # Only draw a random seed when none was given: 0 is a valid, reproducible seed.
    if args.seed is None:
        args.seed = torch.random.seed()

    save_args(output_dir, args)

    if args.image:
        # ``convert`` loads the pixels into a new image, so the file handle can be
        # closed right away instead of staying open for the whole generation.
        with Image.open(args.image) as image_file:
            init_image = image_file.convert("RGB")
    else:
        init_image = None

    pipeline.scheduler = create_scheduler(
        pipeline.scheduler.config, args.scheduler, args.subscheduler
    )

    num_prompts = len(args.prompt)

    for i in range(args.batch_num):
        pipeline.set_progress_bar_config(
            desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
        )

        seed = args.seed + i
        batch_prompts = shuffle_prompts(args.prompt) if args.shuffle else args.prompt
        generator = torch.Generator(device="cuda").manual_seed(seed)
        images = pipeline(
            prompt=batch_prompts,
            negative_prompt=args.negative_prompt,
            height=args.height,
            width=args.width,
            num_images_per_prompt=args.batch_size,
            num_inference_steps=args.steps,
            guidance_scale=args.guidance_scale,
            sag_scale=args.sag_scale,
            generator=generator,
            image=init_image,
            strength=args.image_noise,
        ).images

        for j, image in enumerate(images):
            # NOTE(review): assumes VlpnStableDiffusion returns images interleaved by
            # prompt (p0, p1, ..., p0, p1, ...) when batch_size > 1 — confirm.
            basename = f"{seed}_{j // num_prompts}"
            target_dir = image_dir[j % num_prompts]

            image.save(target_dir / f"{basename}.png")
            image.save(target_dir / f"{basename}.jpg", quality=85)
            with open(target_dir / f"{basename}.txt", "w") as f:
                f.write(batch_prompts[j % num_prompts])

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _split_command_line(line):
    """Split a REPL line shell-style, treating only ``"`` as a quote character.

    Apostrophes are kept literally (``a man's hat`` -> ``["a", "man's", "hat"]``),
    both outside and inside double quotes. Escaping ``'`` as ``\\'`` before
    ``shlex.split`` would leave a stray backslash inside double-quoted text,
    because POSIX shlex only honours ``\\"`` and ``\\\\`` there.

    Raises:
        ValueError: on an unterminated double quote (same as ``shlex.split``).
    """
    lexer = shlex.shlex(line, posix=True)
    lexer.whitespace_split = True
    lexer.commenters = ""
    lexer.quotes = '"'
    return list(lexer)


class CmdParse(cmd.Cmd):
    """Interactive ``dream>`` shell that turns each input line into a generation run."""

    prompt = "dream> "
    commands = []

    def __init__(
        self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
    ):
        super().__init__()

        self.output_dir = output_dir
        self.ti_embeddings_dir = ti_embeddings_dir
        self.lora_embeddings_dir = lora_embeddings_dir
        self.pipeline = pipeline
        self.parser = parser

    def default(self, line):
        """Handle any line that is not a ``do_*`` command.

        ``q`` quits the shell and ``reload_embeddings`` reloads the TI embeddings
        directory. Anything else is parsed with ``self.parser`` (over
        ``default_cmds``) and passed to :func:`generate`. Parse and generation
        errors are printed and the shell keeps running.
        """
        try:
            elements = _split_command_line(line)
        except ValueError as e:
            print(str(e))
            return

        if not elements:
            return

        if elements[0] == "q":
            return True

        if elements[0] == "reload_embeddings":
            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
            return

        try:
            args = run_parser(self.parser, default_cmds, elements)

            # The default prompt is None (not []), so check falsiness rather than len().
            if not args.prompt:
                print("Try again with a prompt!")
                return
        except SystemExit:
            # argparse exits on bad arguments; keep the shell alive instead.
            traceback.print_exc()
            self.parser.print_help()
            return
        except Exception:
            traceback.print_exc()
            return

        try:
            generate(self.output_dir, self.pipeline, args)
        except KeyboardInterrupt:
            print("Generation cancelled.")
        except Exception:
            traceback.print_exc()
            return

    def do_exit(self, line):
        """Exit the shell."""
        return True


def main():
    """Load the pipeline described by the command-line args, then run the shell."""
    logging.basicConfig(stream=sys.stdout, level=logging.WARN)

    args_parser = create_args_parser()
    args = run_parser(args_parser, default_args)

    output_dir = Path(args.output_dir)
    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[
        args.precision
    ]

    pipeline = create_pipeline(args.model, dtype)

    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
    load_lora(pipeline, args.lora_embedding)

    cmd_parser = create_cmd_parser()
    cmd_prompt = CmdParse(
        output_dir,
        args.ti_embeddings_dir,
        # The parsed namespace only has "lora_embedding" (see default_args).
        args.lora_embedding,
        pipeline,
        cmd_parser,
    )
    cmd_prompt.cmdloop()


if __name__ == "__main__":
    main()