import argparse
import datetime
import json
from pathlib import Path

import torch
from torch import autocast
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
from slugify import slugify

from pipelines.stable_diffusion.no_check import NoCheck


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of an inference script."
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--batch_num",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=80,
    )
    parser.add_argument(
        "--scale",
        type=float,  # guidance scale is fractional; `int` would reject the 7.5 default
        default=7.5,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/inference",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
    )
    args = parser.parse_args()

    if args.config is not None:
        # Re-parse with the JSON config's "args" mapping as the namespace defaults,
        # so explicit command-line flags still take precedence over the file.
        with open(args.config, 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    return args


def main():
    args = parse_args()

    # Use the given seed, or draw a fresh one; `is not None` keeps an explicit seed of 0 usable.
    seed = args.seed if args.seed is not None else torch.random.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # One timestamped, seed-tagged directory per run; the prompt slug is capped at 80 chars.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load each component of the trained pipeline from its subdirectory.
    # The model weights are loaded in bfloat16; the tokenizer and feature
    # extractor carry no weights, so they take no torch_dtype.
    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

    pipeline = StableDiffusionPipeline(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
        ),
        safety_checker=NoCheck(),  # repo-local stub that disables the safety checker
        feature_extractor=feature_extractor
    )
    pipeline.enable_attention_slicing()
    pipeline.to("cuda")

    with autocast("cuda"):
        for i in range(args.batch_num):
            images = pipeline(
                [args.prompt] * args.batch_size,
                num_inference_steps=args.steps,
                guidance_scale=args.scale,
                generator=generator,
            ).images

            # Number files sequentially across batches: batch i contributes
            # indices [i * batch_size, (i + 1) * batch_size).
            for j, image in enumerate(images):
                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))


if __name__ == "__main__":
    main()
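
# Usage sketch (the script filename, model path, and prompt below are
# hypothetical placeholders, not defined by this file):
#
#   python inference.py \
#       --model ./dreambooth-output \
#       --prompt "a photo of sks dog in a bucket" \
#       --batch_size 4 --batch_num 10 --steps 50 --scale 7.5 --seed 42
#
# The same settings can instead come from a JSON file via --config. A minimal
# sketch of the expected layout, assuming only that the file nests the values
# under an "args" key (the one key this script actually reads):
#
#   {
#       "args": {
#           "model": "./dreambooth-output",
#           "prompt": "a photo of sks dog in a bucket",
#           "steps": 50,
#           "scale": 7.5
#       }
#   }
#
# Flags passed on the command line override values from the config, because
# parse_args() is re-run with the config contents only as namespace defaults.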