From 49de8142f523aef3f6adfd0c33a9a160aa7400c0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 2 Oct 2022 12:56:58 +0200
Subject: WIP: img2img

---
 infer.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 67 insertions(+), 23 deletions(-)

(limited to 'infer.py')

diff --git a/infer.py b/infer.py
index d917239..b440cb6 100644
--- a/infer.py
+++ b/infer.py
@@ -8,13 +8,47 @@ from pathlib import Path
 from torch import autocast
 import torch
 import json
+from PIL import Image

 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler

+default_args = {
+    "model": None,
+    "scheduler": "euler_a",
+    "output_dir": "output/inference",
+    "config": None,
+}
+
+
+default_cmds = {
+    "prompt": None,
+    "negative_prompt": None,
+    "image": None,
+    "image_strength": .7,
+    "width": 512,
+    "height": 512,
+    "batch_size": 1,
+    "batch_num": 1,
+    "steps": 50,
+    "guidance_scale": 7.0,
+    "seed": None,
+    "config": None,
+}
+
+
+def merge_dicts(d1, *args):
+    d1 = d1.copy()
+
+    for d in args:
+        d1.update({k: v for (k, v) in d.items() if v is not None})
+
+    return d1
+
+
 def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -22,23 +56,19 @@ def create_args_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--scheduler",
         type=str,
         choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
     )
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/inference",
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )

     return parser
@@ -51,66 +81,69 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
-        default=None,
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+    )
+    parser.add_argument(
+        "--image_strength",
+        type=float,
     )
     parser.add_argument(
         "--width",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--height",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
-        type=int,
-        default=7,
+        type=float,
     )
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )

     return parser


-def run_parser(parser, input=None):
+def run_parser(parser, defaults, input=None):
     args = parser.parse_known_args(input)[0]
+    conf_args = argparse.Namespace()

     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_known_args(
+            conf_args = parser.parse_known_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))[0]

-    return args
+    res = defaults.copy()
+    for dict in [vars(conf_args), vars(args)]:
+        res.update({k: v for (k, v) in dict.items() if v is not None})
+
+    return argparse.Namespace(**res)


 def save_args(basepath, args, extra={}):
@@ -146,7 +179,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )

-    pipeline = CLIPGuidedStableDiffusion(
+    pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
@@ -154,7 +187,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
-    pipeline.enable_attention_slicing()
+    # pipeline.enable_attention_slicing()
     pipeline.to("cuda")

     print("Pipeline loaded.")
@@ -171,6 +204,13 @@ def generate(output_dir, pipeline, args):

     save_args(output_dir, args)

+    if args.image:
+        init_image = Image.open(args.image)
+        if not init_image.mode == "RGB":
+            init_image = init_image.convert("RGB")
+    else:
+        init_image = None
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
@@ -184,11 +224,15 @@ def generate(output_dir, pipeline, args):
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
                 generator=generator,
+                latents=init_image,
             ).images

             for j, image in enumerate(images):
                 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))

+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+

 class CmdParse(cmd.Cmd):
     prompt = 'dream> '
@@ -213,7 +257,7 @@ class CmdParse(cmd.Cmd):
             return True

         try:
-            args = run_parser(self.parser, elements)
+            args = run_parser(self.parser, default_cmds, elements)
         except SystemExit:
             self.parser.print_help()

@@ -233,7 +277,7 @@ def main():
     logging.basicConfig(stream=sys.stdout, level=logging.WARN)

     args_parser = create_args_parser()
-    args = run_parser(args_parser)
+    args = run_parser(args_parser, default_args)

     output_dir = Path(args.output_dir)
     pipeline = create_pipeline(args.model, args.scheduler)
--
cgit v1.2.3-54-g00ecf
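
For illustration only, and not part of the commit: the reworked run_parser resolves each option from three sources in increasing precedence (the hard-coded defaults dict, then the JSON config file named by --config, then explicit command-line flags), and None values never overwrite an earlier source. The standalone sketch below reproduces that merge behaviour; the trimmed-down defaults dict, the helper name merge_namespaces, and the example prompt and steps values are assumptions made for the sketch, not code from this repository.

import argparse

# Trimmed-down stand-in for the patch's default_cmds dict (hypothetical subset).
default_cmds = {
    "prompt": None,
    "image": None,
    "image_strength": 0.7,
    "steps": 50,
    "guidance_scale": 7.0,
}


def merge_namespaces(defaults, *namespaces):
    # Later sources win, and None values never overwrite an earlier source.
    # This mirrors the precedence the patch sets up:
    # defaults < JSON config < explicit CLI flags.
    res = defaults.copy()
    for ns in namespaces:
        res.update({k: v for k, v in vars(ns).items() if v is not None})
    return argparse.Namespace(**res)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--steps", type=int)

    # Hypothetical inputs: a config file that sets prompt and steps=30,
    # and a CLI flag that sets steps=20. The CLI flag wins for steps,
    # the config supplies the prompt, and guidance_scale falls back to the default.
    cli_args = parser.parse_known_args(["--steps", "20"])[0]
    conf_args = argparse.Namespace(prompt="a photo of a cat", steps=30)

    merged = merge_namespaces(default_cmds, conf_args, cli_args)
    print(merged.prompt, merged.steps, merged.guidance_scale)
    # -> a photo of a cat 20 7.0

Under the same assumptions, an interactive session could exercise the new img2img flags with something like: dream> a painting of a fox --image init.png --image_strength 0.6 (prompt text and file name are made up).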