From 64c594869135354a38353551bd58a93e15bd5b85 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 20:57:43 +0200 Subject: Small performance improvements --- infer.py | 18 +++++++++++++----- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 18 ++++++++++-------- schedulers/scheduling_euler_a.py | 9 +++++---- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/infer.py b/infer.py index f2c380f..b15b17f 100644 --- a/infer.py +++ b/infer.py @@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler default_args = { "model": None, "scheduler": "euler_a", + "precision": "bf16", "output_dir": "output/inference", "config": None, } @@ -28,7 +29,7 @@ default_cmds = { "prompt": None, "negative_prompt": None, "image": None, - "image_strength": .3, + "image_noise": .7, "width": 512, "height": 512, "batch_size": 1, @@ -62,6 +63,11 @@ def create_args_parser(): type=str, choices=["plms", "ddim", "klms", "euler_a"], ) + parser.add_argument( + "--precision", + type=str, + choices=["fp32", "fp16", "bf16"], + ) parser.add_argument( "--output_dir", type=str, @@ -91,7 +97,7 @@ def create_cmd_parser(): type=str, ) parser.add_argument( - "--image_strength", + "--image_noise", type=float, ) parser.add_argument( @@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}): json.dump(info, f, indent=4) -def create_pipeline(model, scheduler, dtype=torch.bfloat16): +def create_pipeline(model, scheduler, dtype): print("Loading Stable Diffusion pipeline...") tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) @@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args): guidance_scale=args.guidance_scale, generator=generator, latents=init_image, - strength=args.image_strength, + strength=args.image_noise, ).images for j, image in enumerate(images): @@ -279,9 +285,11 @@ def main(): args_parser = create_args_parser() args = run_parser(args_parser, default_args) + output_dir = Path(args.output_dir) + dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] - pipeline = create_pipeline(args.model, args.scheduler) + pipeline = create_pipeline(args.model, args.scheduler, dtype) cmd_parser = create_cmd_parser() cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) cmd_prompt.cmdloop() diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index b4c85e9..8fbe5f9 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline): # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_device = "cpu" if self.device.type == "mps" else self.device + latents_dtype = text_embeddings.dtype latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: - latents = torch.randn( - latents_shape, - generator=generator, - device=latents_device, - dtype=text_embeddings.dtype, - ) + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) elif isinstance(latents, PIL.Image.Image): latents = preprocess(latents, width, height) latent_dist = self.vae.encode(latents.to(self.device)).latent_dist @@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline): else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + if latents.device != self.device: + raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if ensure_sigma: diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index a2d0e9f..d7fea85 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None): if n is None: return append_zero(sigmas.flip(0)) t_max = len(sigmas) - 1 # = 999 - t = torch.linspace(t_max, 0, n, device=sigmas.device) + t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype) return append_zero(t_to_sigma(t, sigmas)) # from k_samplers utils.py @@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): - sigma = sigma.to(Unet.device) - DSsigmas = DSsigmas.to(Unet.device) + sigma = sigma.to(dtype=input.dtype, device=Unet.device) + DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device) c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] + # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}") eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), encoder_hidden_states=kwargs['cond']).sample return input + eps * c_out @@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): d = to_d(latents, s, model_output) dt = sigma_down - s latents = latents + d * dt - latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, + latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, generator=generator) * sigma_up return SchedulerOutput(prev_sample=latents) -- cgit v1.2.3-54-g00ecf