From 64c594869135354a38353551bd58a93e15bd5b85 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 20:57:43 +0200 Subject: Small performance improvements --- infer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index f2c380f..b15b17f 100644 --- a/infer.py +++ b/infer.py @@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler default_args = { "model": None, "scheduler": "euler_a", + "precision": "bf16", "output_dir": "output/inference", "config": None, } @@ -28,7 +29,7 @@ default_cmds = { "prompt": None, "negative_prompt": None, "image": None, - "image_strength": .3, + "image_noise": .7, "width": 512, "height": 512, "batch_size": 1, @@ -62,6 +63,11 @@ def create_args_parser(): type=str, choices=["plms", "ddim", "klms", "euler_a"], ) + parser.add_argument( + "--precision", + type=str, + choices=["fp32", "fp16", "bf16"], + ) parser.add_argument( "--output_dir", type=str, @@ -91,7 +97,7 @@ def create_cmd_parser(): type=str, ) parser.add_argument( - "--image_strength", + "--image_noise", type=float, ) parser.add_argument( @@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}): json.dump(info, f, indent=4) -def create_pipeline(model, scheduler, dtype=torch.bfloat16): +def create_pipeline(model, scheduler, dtype): print("Loading Stable Diffusion pipeline...") tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) @@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args): guidance_scale=args.guidance_scale, generator=generator, latents=init_image, - strength=args.image_strength, + strength=args.image_noise, ).images for j, image in enumerate(images): @@ -279,9 +285,11 @@ def main(): args_parser = create_args_parser() args = run_parser(args_parser, default_args) + output_dir = Path(args.output_dir) + dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] - pipeline = create_pipeline(args.model, args.scheduler) + pipeline = create_pipeline(args.model, args.scheduler, dtype) cmd_parser = create_cmd_parser() cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) cmd_prompt.cmdloop() -- cgit v1.2.3-54-g00ecf