From 64c594869135354a38353551bd58a93e15bd5b85 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 20:57:43 +0200 Subject: Small performance improvements --- schedulers/scheduling_euler_a.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'schedulers') diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index a2d0e9f..d7fea85 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None): if n is None: return append_zero(sigmas.flip(0)) t_max = len(sigmas) - 1 # = 999 - t = torch.linspace(t_max, 0, n, device=sigmas.device) + t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype) return append_zero(t_to_sigma(t, sigmas)) # from k_samplers utils.py @@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): - sigma = sigma.to(Unet.device) - DSsigmas = DSsigmas.to(Unet.device) + sigma = sigma.to(dtype=input.dtype, device=Unet.device) + DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device) c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] + # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}") eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), encoder_hidden_states=kwargs['cond']).sample return input + eps * c_out @@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): d = to_d(latents, s, model_output) dt = sigma_down - s latents = latents + d * dt - latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, + latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, generator=generator) * sigma_up return SchedulerOutput(prev_sample=latents) -- cgit v1.2.3-54-g00ecf