import argparse
import datetime
import logging
import sys
import shlex
import cmd
from pathlib import Path
from typing import Optional
import torch
import json
import traceback
from PIL import Image
from slugify import slugify
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
PNDMScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
DDIMScheduler,
LMSDiscreteScheduler,
EulerAncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
DEISMultistepScheduler,
UniPCMultistepScheduler,
)
from peft import LoraConfig, LoraModel, set_peft_model_state_dict
from safetensors.torch import load_file
from transformers import CLIPTextModel
from data.keywords import str_to_keywords, keywords_to_str
from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from util.files import load_config, load_embeddings_from_dir
from util.ti import load_embeddings
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
default_args = {
"model": "stabilityai/stable-diffusion-2-1",
"precision": "fp32",
"ti_embeddings_dir": "embeddings_ti",
"lora_embeddings_dir": None,
"output_dir": "output/inference",
"config": None,
}
default_cmds = {
"project": "",
"scheduler": "unipc",
"subscheduler": None,
"template": "{}",
"prompt": None,
"negative_prompt": None,
"shuffle": False,
"image": None,
"image_noise": 0.7,
"width": 768,
"height": 768,
"batch_size": 1,
"batch_num": 1,
"steps": 30,
"guidance_scale": 7.0,
"sag_scale": 0,
"seed": None,
"config": None,
}
def merge_dicts(d1, *args):
d1 = d1.copy()
for d in args:
d1.update({k: v for (k, v) in d.items() if v is not None})
return d1
def create_args_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--model",
type=str,
)
parser.add_argument(
"--precision",
type=str,
choices=["fp32", "fp16", "bf16"],
)
parser.add_argument(
"--ti_embeddings_dir",
type=str,
)
parser.add_argument(
"--lora_embeddings_dir",
type=str,
)
parser.add_argument(
"--output_dir",
type=str,
)
parser.add_argument(
"--config",
type=str,
)
return parser
def create_cmd_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--project",
type=str,
default=None,
help="The name of the current project.",
)
parser.add_argument(
"--scheduler",
type=str,
choices=[
"plms",
"ddim",
"klms",
"dpmsm",
"dpmss",
"euler_a",
"kdpm2",
"kdpm2_a",
"deis",
"unipc",
],
)
parser.add_argument(
"--subscheduler",
type=str,
default=None,
choices=[
"plms",
"ddim",
"klms",
"dpmsm",
"dpmss",
"euler_a",
"kdpm2",
"kdpm2_a",
"deis",
],
)
parser.add_argument(
"--template",
type=str,
)
parser.add_argument(
"--prompt",
type=str,
nargs="+",
)
parser.add_argument(
"--negative_prompt",
type=str,
nargs="*",
)
parser.add_argument(
"--shuffle",
type=bool,
)
parser.add_argument(
"--image",
type=str,
)
parser.add_argument(
"--image_noise",
type=float,
)
parser.add_argument(
"--width",
type=int,
)
parser.add_argument(
"--height",
type=int,
)
parser.add_argument(
"--batch_size",
type=int,
)
parser.add_argument(
"--batch_num",
type=int,
)
parser.add_argument(
"--steps",
type=int,
)
parser.add_argument(
"--guidance_scale",
type=float,
)
parser.add_argument(
"--sag_scale",
type=float,
)
parser.add_argument(
"--seed",
type=int,
)
parser.add_argument(
"--config",
type=str,
)
return parser
def run_parser(parser, defaults, input=None):
args = parser.parse_known_args(input)[0]
conf_args = argparse.Namespace()
if args.config is not None:
conf_args = load_config(args.config)
conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[
0
]
res = defaults.copy()
for dict in [vars(conf_args), vars(args)]:
res.update({k: v for (k, v) in dict.items() if v is not None})
return argparse.Namespace(**res)
def save_args(basepath, args, extra={}):
info = {"args": vars(args)}
info["args"].update(extra)
with open(f"{basepath}/args.json", "w") as f:
json.dump(info, f, indent=4)
def load_embeddings_dir(pipeline, embeddings_dir):
added_tokens, added_ids = load_embeddings_from_dir(
pipeline.tokenizer,
pipeline.text_encoder.text_model.embeddings,
Path(embeddings_dir),
)
pipeline.text_encoder.text_model.embeddings.persist()
print(
f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
)
def load_lora(pipeline, path):
if path is None:
return
path = Path(path)
with open(path / "lora_config.json", "r") as f:
lora_config = json.load(f)
tensor_files = list(path.glob("*_end.safetensors"))
if len(tensor_files) == 0:
return
lora_checkpoint_sd = load_file(path / tensor_files[0])
unet_lora_ds = {
k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
}
text_encoder_lora_ds = {
k.replace("text_encoder_", ""): v
for k, v in lora_checkpoint_sd.items()
if "text_encoder_" in k
}
ti_lora_ds = {
k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
}
unet_config = LoraConfig(**lora_config["peft_config"])
pipeline.unet = LoraModel(unet_config, pipeline.unet)
set_peft_model_state_dict(pipeline.unet, unet_lora_ds)
if "text_encoder_peft_config" in lora_config:
text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
tokens = [k for k, _ in ti_lora_ds]
token_embeddings = [v for _, v in ti_lora_ds]
added_tokens, added_ids = load_embeddings(
tokenizer=pipeline.tokenizer,
embeddings=pipeline.text_encoder.text_model.embeddings,
tokens=tokens,
token_embeddings=token_embeddings,
)
pipeline.text_encoder.text_model.embeddings.persist()
print(
f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
)
return
def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
if scheduler == "plms":
return PNDMScheduler.from_config(config)
elif scheduler == "klms":
return LMSDiscreteScheduler.from_config(config)
elif scheduler == "ddim":
return DDIMScheduler.from_config(config)
elif scheduler == "dpmsm":
return DPMSolverMultistepScheduler.from_config(config)
elif scheduler == "dpmss":
return DPMSolverSinglestepScheduler.from_config(config)
elif scheduler == "euler_a":
return EulerAncestralDiscreteScheduler.from_config(config)
elif scheduler == "kdpm2":
return KDPM2DiscreteScheduler.from_config(config)
elif scheduler == "kdpm2_a":
return KDPM2AncestralDiscreteScheduler.from_config(config)
elif scheduler == "deis":
return DEISMultistepScheduler.from_config(config)
elif scheduler == "unipc":
if subscheduler is None:
return UniPCMultistepScheduler.from_config(config)
else:
return UniPCMultistepScheduler.from_config(
config,
solver_p=create_scheduler(config, subscheduler),
)
else:
raise ValueError(f'Unknown scheduler "{scheduler}"')
def create_pipeline(model, dtype):
print(f"Loading Stable Diffusion pipeline: {model}...")
tokenizer = MultiCLIPTokenizer.from_pretrained(
model, subfolder="tokenizer", torch_dtype=dtype
)
text_encoder = CLIPTextModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype
)
vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(
model, subfolder="unet", torch_dtype=dtype
)
scheduler = DDIMScheduler.from_pretrained(
model, subfolder="scheduler", torch_dtype=dtype
)
patch_managed_embeddings(text_encoder)
pipeline = VlpnStableDiffusion(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
)
pipeline.enable_xformers_memory_efficient_attention()
# pipeline.unet = torch.compile(pipeline.unet)
pipeline.enable_vae_slicing()
pipeline.to("cuda")
print("Pipeline loaded.")
return pipeline
def shuffle_prompts(prompts: list[str]) -> list[str]:
return [
keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
]
@torch.inference_mode()
def generate(output_dir: Path, pipeline, args):
if isinstance(args.prompt, str):
args.prompt = [args.prompt]
args.prompt = [args.template.format(prompt) for prompt in args.prompt]
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
image_dir = []
if len(args.prompt) != 1:
if len(args.project) != 0:
output_dir = output_dir / f"{now}_{slugify(args.project)}"
else:
output_dir = output_dir / now
for prompt in args.prompt:
dir = output_dir / slugify(prompt)[:100]
dir.mkdir(parents=True, exist_ok=True)
image_dir.append(dir)
else:
output_dir = output_dir / f"{now}_{slugify(args.prompt[0])[:100]}"
output_dir.mkdir(parents=True, exist_ok=True)
image_dir.append(output_dir)
args.seed = args.seed or torch.random.seed()
save_args(output_dir, args)
if args.image:
init_image = Image.open(args.image)
if not init_image.mode == "RGB":
init_image = init_image.convert("RGB")
else:
init_image = None
pipeline.scheduler = create_scheduler(
pipeline.scheduler.config, args.scheduler, args.subscheduler
)
for i in range(args.batch_num):
pipeline.set_progress_bar_config(
desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
)
seed = args.seed + i
prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt
generator = torch.Generator(device="cuda").manual_seed(seed)
images = pipeline(
prompt=prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
generator=generator,
guidance_scale=args.guidance_scale,
num_images_per_prompt=args.batch_size,
num_inference_steps=args.steps,
sag_scale=args.sag_scale,
image=init_image,
strength=args.image_noise,
).images
for j, image in enumerate(images):
basename = f"{seed}_{j // len(args.prompt)}"
dir = image_dir[j % len(args.prompt)]
image.save(dir / f"{basename}.png")
image.save(dir / f"{basename}.jpg", quality=85)
with open(dir / f"{basename}.txt", "w") as f:
f.write(prompt[j % len(args.prompt)])
if torch.cuda.is_available():
torch.cuda.empty_cache()
class CmdParse(cmd.Cmd):
prompt = "dream> "
commands = []
def __init__(
self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
):
super().__init__()
self.output_dir = output_dir
self.ti_embeddings_dir = ti_embeddings_dir
self.lora_embeddings_dir = lora_embeddings_dir
self.pipeline = pipeline
self.parser = parser
def default(self, line):
line = line.replace("'", "\\'")
try:
elements = shlex.split(line)
except ValueError as e:
print(str(e))
return
if elements[0] == "q":
return True
if elements[0] == "reload_embeddings":
load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
return
try:
args = run_parser(self.parser, default_cmds, elements)
if len(args.prompt) == 0:
print("Try again with a prompt!")
return
except SystemExit:
traceback.print_exc()
self.parser.print_help()
return
except Exception as e:
traceback.print_exc()
return
try:
generate(self.output_dir, self.pipeline, args)
except KeyboardInterrupt:
print("Generation cancelled.")
except Exception as e:
traceback.print_exc()
return
def do_exit(self, line):
return True
def main():
logging.basicConfig(stream=sys.stdout, level=logging.WARN)
args_parser = create_args_parser()
args = run_parser(args_parser, default_args)
output_dir = Path(args.output_dir)
dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[
args.precision
]
pipeline = create_pipeline(args.model, dtype)
# load_embeddings_dir(pipeline, args.ti_embeddings_dir)
# load_lora(pipeline, args.lora_embeddings_dir)
# pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
cmd_parser = create_cmd_parser()
cmd_prompt = CmdParse(
output_dir,
args.ti_embeddings_dir,
args.lora_embeddings_dir,
pipeline,
cmd_parser,
)
cmd_prompt.cmdloop()
if __name__ == "__main__":
main()