From a8a5abae42f6f42056cc27e0cf5313aab080c3a7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 22:16:13 +0200
Subject: Various improvements, added inference script

---
 infer.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 infer.py

diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b9e9ff7
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,121 @@
+import argparse
+import datetime
+import json
+from pathlib import Path
+
+import torch
+from torch import autocast
+from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from slugify import slugify
+
+from pipelines.stable_diffusion.no_check import NoCheck
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Inference script for a trained Stable Diffusion model."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--batch_num",
+        type=int,
+        default=50,
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=80,
+    )
+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=7.5,
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        # Values from the config file serve as defaults; explicit CLI flags win.
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    seed = args.seed if args.seed is not None else torch.random.seed()
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    # Give every run its own directory, named after timestamp, seed, and prompt.
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Load the trained model components in bfloat16 to reduce VRAM usage.
+    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
+    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
+    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
+    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
+
+    pipeline = StableDiffusionPipeline(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        ),
+        safety_checker=NoCheck(),
+        feature_extractor=feature_extractor
+    )
+    pipeline.enable_attention_slicing()
+    pipeline.to("cuda")
+
+    with autocast("cuda"):
+        for i in range(args.batch_num):
+            images = pipeline(
+                [args.prompt] * args.batch_size,
+                num_inference_steps=args.steps,
+                guidance_scale=args.scale,
+                generator=generator,
+            ).images
+
+            for j, image in enumerate(images):
+                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
+
+
+if __name__ == "__main__":
+    main()
--
cgit v1.2.3-54-g00ecf
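
Usage note: when --config is given, the script re-parses the command line with
the JSON file's "args" object as defaults, so explicitly passed flags still
override the file. A minimal sketch of such a config and an invocation — the
file name and values here are hypothetical, and the keys must match the
argument names defined in parse_args():

    {
        "args": {
            "model": "path-to-your-trained-model",
            "prompt": "A photo of sks dog in a bucket",
            "batch_size": 2,
            "batch_num": 10,
            "steps": 80,
            "scale": 7.5
        }
    }

    python infer.py --config inference.json --seed 42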