summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-02 20:57:43 +0200
committerVolpeon <git@volpeon.ink>2022-10-02 20:57:43 +0200
commit64c594869135354a38353551bd58a93e15bd5b85 (patch)
tree2bcc085a396824f78e58c90b1f6e9553c7f5c8c1 /schedulers
parentFix img2img (diff)
downloadtextual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.gz
textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.bz2
textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.zip
Small performance improvements
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index a2d0e9f..d7fea85 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -36,7 +36,7 @@ def get_sigmas(sigmas, n=None):
36 if n is None: 36 if n is None:
37 return append_zero(sigmas.flip(0)) 37 return append_zero(sigmas.flip(0))
38 t_max = len(sigmas) - 1 # = 999 38 t_max = len(sigmas) - 1 # = 999
39 t = torch.linspace(t_max, 0, n, device=sigmas.device) 39 t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
40 return append_zero(t_to_sigma(t, sigmas)) 40 return append_zero(t_to_sigma(t, sigmas))
41 41
42# from k_samplers utils.py 42# from k_samplers utils.py
@@ -91,9 +91,10 @@ def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
91 91
92 92
93def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): 93def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
94 sigma = sigma.to(Unet.device) 94 sigma = sigma.to(dtype=input.dtype, device=Unet.device)
95 DSsigmas = DSsigmas.to(Unet.device) 95 DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
96 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] 96 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
97 # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
97 eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), 98 eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
98 encoder_hidden_states=kwargs['cond']).sample 99 encoder_hidden_states=kwargs['cond']).sample
99 return input + eps * c_out 100 return input + eps * c_out
@@ -226,7 +227,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
226 d = to_d(latents, s, model_output) 227 d = to_d(latents, s, model_output)
227 dt = sigma_down - s 228 dt = sigma_down - s
228 latents = latents + d * dt 229 latents = latents + d * dt
229 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, 230 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
230 generator=generator) * sigma_up 231 generator=generator) * sigma_up
231 232
232 return SchedulerOutput(prev_sample=latents) 233 return SchedulerOutput(prev_sample=latents)