import argparse
import datetime
import json
import logging
from pathlib import Path

import torch
from torch import autocast
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from slugify import slugify

from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
from schedulers.scheduling_euler_a import EulerAScheduler


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of an inference script."
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--width",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--height",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--batch_num",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=120,
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["plms", "ddim", "klms", "euler_a"],
        default="euler_a",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,  # was `int`, which rejected fractional values such as the 7.5 default
        default=7.5,
    )
    parser.add_argument(
        "--clip_guidance_scale",
        type=float,  # was `int`; float for consistency with --guidance_scale
        default=100,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=torch.random.seed(),  # fresh random seed unless one is given explicitly
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/inference",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
    )

    args = parser.parse_args()
    if args.config is not None:
        # Re-parse with the JSON config values pre-seeded into the namespace,
        # so explicit CLI flags still take precedence over the config file.
        with open(args.config, "rt") as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    return args


def save_args(basepath, args, extra={}):
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(f"{basepath}/args.json", "w") as f:
        json.dump(info, f, indent=4)


def gen(args, output_dir):
    # Load the individual pipeline components in bfloat16 to reduce VRAM usage.
    tokenizer = CLIPTokenizer.from_pretrained(
        args.model + '/tokenizer', torch_dtype=torch.bfloat16)
    text_encoder = CLIPTextModel.from_pretrained(
        args.model + '/text_encoder', torch_dtype=torch.bfloat16)
    clip_model = CLIPModel.from_pretrained(
        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
    vae = AutoencoderKL.from_pretrained(
        args.model + '/vae', torch_dtype=torch.bfloat16)
    unet = UNet2DConditionModel.from_pretrained(
        args.model + '/unet', torch_dtype=torch.bfloat16)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)

    if args.scheduler == "plms":
        scheduler = PNDMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True
        )
    elif args.scheduler == "klms":
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear"
        )
    elif args.scheduler == "ddim":
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False
        )
    else:
        scheduler = EulerAScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False
        )

    pipeline = CLIPGuidedStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        clip_model=clip_model,
        scheduler=scheduler,
        feature_extractor=feature_extractor
    )
    pipeline.enable_attention_slicing()
    pipeline.to("cuda")

    with autocast("cuda"):
        for i in range(args.batch_num):
            # Derive a per-batch seed so each batch is reproducible on its own.
            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
            images = pipeline(
                prompt=[args.prompt] * args.batch_size,
                height=args.height,
                width=args.width,
                negative_prompt=args.negative_prompt,
                num_inference_steps=args.steps,
                guidance_scale=args.guidance_scale,
                clip_guidance_scale=args.clip_guidance_scale,
                generator=generator,
            ).images

            for j, image in enumerate(images):
                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))


def main():
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
    output_dir.mkdir(parents=True, exist_ok=True)

    save_args(output_dir, args)

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

    gen(args, output_dir)


if __name__ == "__main__":
    main()
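
# A minimal sketch of how this script might be invoked; the script name, model
# path, and prompt below are hypothetical placeholders, not values from the repo:
#
#   python infer.py --model models/my-model --prompt "a watercolor fox" \
#       --scheduler euler_a --steps 120 --batch_size 2 --batch_num 4
#
# The same arguments can instead be supplied via --config. parse_args() expects
# a JSON file with the values wrapped under an "args" key, which matches the
# args.json that save_args() writes into each output directory, so a previous
# run's args.json can be replayed directly:
#
#   {
#       "args": {
#           "model": "models/my-model",
#           "prompt": "a watercolor fox",
#           "scheduler": "euler_a",
#           "steps": 120
#       }
#   }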