From 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 30 Sep 2022 14:13:51 +0200 Subject: Added custom SD pipeline + euler_a scheduler --- infer.py | 111 +++++++++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 80 insertions(+), 31 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index f2007e9..de3d792 100644 --- a/infer.py +++ b/infer.py @@ -1,18 +1,15 @@ import argparse import datetime +import logging from pathlib import Path from torch import autocast -from diffusers import StableDiffusionPipeline import torch import json -from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler -from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor +from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler +from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor from slugify import slugify -from pipelines.stable_diffusion.no_check import NoCheck - -model_id = "path-to-your-trained-model" - -prompt = "A photo of sks dog in a bucket" +from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion +from schedulers.scheduling_euler_a import EulerAScheduler def parse_args(): @@ -29,6 +26,21 @@ def parse_args(): type=str, default=None, ) + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + ) + parser.add_argument( + "--width", + type=int, + default=512, + ) + parser.add_argument( + "--height", + type=int, + default=512, + ) parser.add_argument( "--batch_size", type=int, @@ -42,17 +54,28 @@ def parse_args(): parser.add_argument( "--steps", type=int, - default=80, + default=120, ) parser.add_argument( - "--scale", + "--scheduler", + type=str, + choices=["plms", "ddim", "klms", "euler_a"], + default="euler_a", + ) + parser.add_argument( + "--guidance_scale", type=int, default=7.5, ) + parser.add_argument( + "--clip_guidance_scale", + type=int, + default=100, + ) parser.add_argument( "--seed", type=int, - default=None, + default=torch.random.seed(), ) parser.add_argument( "--output_dir", @@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}): json.dump(info, f, indent=4) -def main(): - args = parse_args() - - seed = args.seed or torch.random.seed() - - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") - output_dir.mkdir(parents=True, exist_ok=True) - save_args(output_dir, args) - +def gen(args, output_dir): tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) + clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) - feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16) + feature_extractor = CLIPFeatureExtractor.from_pretrained( + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) - pipeline = StableDiffusionPipeline( + if args.scheduler == "plms": + scheduler = PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + elif args.scheduler == "klms": + scheduler = LMSDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + elif args.scheduler == "ddim": + scheduler = DDIMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False + ) + else: + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False + ) + + pipeline = CLIPGuidedStableDiffusion( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), - safety_checker=NoCheck(), + clip_model=clip_model, + scheduler=scheduler, feature_extractor=feature_extractor ) pipeline.enable_attention_slicing() @@ -113,16 +144,34 @@ def main(): with autocast("cuda"): for i in range(args.batch_num): - generator = torch.Generator(device="cuda").manual_seed(seed + i) + generator = torch.Generator(device="cuda").manual_seed(args.seed + i) images = pipeline( - [args.prompt] * args.batch_size, + prompt=[args.prompt] * args.batch_size, + height=args.height, + width=args.width, + negative_prompt=args.negative_prompt, num_inference_steps=args.steps, - guidance_scale=args.scale, + guidance_scale=args.guidance_scale, + clip_guidance_scale=args.clip_guidance_scale, generator=generator, ).images for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) + + +def main(): + args = parse_args() + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") + output_dir.mkdir(parents=True, exist_ok=True) + + save_args(output_dir, args) + + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + + gen(args, output_dir) if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf