From b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 27 Nov 2022 16:57:29 +0100
Subject: Update

---
 infer.py | 52 +++++++++++++++++++++++-----------------------------
 1 file changed, 23 insertions(+), 29 deletions(-)

diff --git a/infer.py b/infer.py
index 2bf9cb3..ab5f247 100644
--- a/infer.py
+++ b/infer.py
@@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True
 
 default_args = {
     "model": None,
-    "scheduler": "dpmpp",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
@@ -29,6 +28,7 @@ default_args = {
 
 
 default_cmds = {
+    "scheduler": "dpmpp",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -61,11 +61,6 @@ def create_args_parser():
         "--model",
         type=str,
     )
-    parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
-    )
     parser.add_argument(
         "--precision",
         type=str,
@@ -91,6 +86,11 @@ def create_cmd_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
     )
+    parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
     print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
+def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
+    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
     load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
-    if scheduler == "plms":
-        scheduler = PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        )
-    elif scheduler == "klms":
-        scheduler = LMSDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-    elif scheduler == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
-        )
-    elif scheduler == "dpmpp":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-    else:
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
+    if args.scheduler == "plms":
+        pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "klms":
+        pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "ddim":
+        pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmpp":
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "euler_a":
+        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
@@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd):
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
             print('Generation cancelled.')
+        except Exception as e:
+            print(e)
+            return
 
     def do_exit(self, line):
         return True
@@ -345,7 +339,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
-- 
cgit v1.2.3-54-g00ecf
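
Note: the swap this patch performs in generate() relies on the diffusers scheduler
config round-trip, where any scheduler class can be rebuilt from another scheduler's
config. A minimal standalone sketch of the same pattern, assuming a stock diffusers
StableDiffusionPipeline rather than this repo's VlpnStableDiffusion, with "./model"
as a placeholder path:

    # Sketch of the from_config scheduler swap; not part of the patch.
    import torch
    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

    # Load the pipeline; this also loads the checkpoint's default scheduler.
    pipe = StableDiffusionPipeline.from_pretrained("./model", torch_dtype=torch.float16)

    # Rebuild a DPM++ scheduler from the current scheduler's config so the
    # checkpoint's beta schedule and timestep settings carry over unchanged.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Rebuilding from pipeline.scheduler.config, as the new generate() code does, avoids
hard-coding beta_start/beta_end values the way the removed create_pipeline() branch
did, and lets the scheduler be chosen per prompt instead of once at pipeline load.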