diff options
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index a2d0e9f..d7fea85 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None): | |||
36 | if n is None: | 36 | if n is None: |
37 | return append_zero(sigmas.flip(0)) | 37 | return append_zero(sigmas.flip(0)) |
38 | t_max = len(sigmas) - 1 # = 999 | 38 | t_max = len(sigmas) - 1 # = 999 |
39 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | 39 | t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype) |
40 | return append_zero(t_to_sigma(t, sigmas)) | 40 | return append_zero(t_to_sigma(t, sigmas)) |
41 | 41 | ||
42 | # from k_samplers utils.py | 42 | # from k_samplers utils.py |
@@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): | |||
91 | 91 | ||
92 | 92 | ||
93 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): | 93 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): |
94 | sigma = sigma.to(Unet.device) | 94 | sigma = sigma.to(dtype=input.dtype, device=Unet.device) |
95 | DSsigmas = DSsigmas.to(Unet.device) | 95 | DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device) |
96 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | 96 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] |
97 | # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}") | ||
97 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), | 98 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), |
98 | encoder_hidden_states=kwargs['cond']).sample | 99 | encoder_hidden_states=kwargs['cond']).sample |
99 | return input + eps * c_out | 100 | return input + eps * c_out |
@@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
226 | d = to_d(latents, s, model_output) | 227 | d = to_d(latents, s, model_output) |
227 | dt = sigma_down - s | 228 | dt = sigma_down - s |
228 | latents = latents + d * dt | 229 | latents = latents + d * dt |
229 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | 230 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, |
230 | generator=generator) * sigma_up | 231 | generator=generator) * sigma_up |
231 | 232 | ||
232 | return SchedulerOutput(prev_sample=latents) | 233 | return SchedulerOutput(prev_sample=latents) |