summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py6
1 files changed, 1 insertions, 5 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 57a56de..29ebd07 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
216 216
217 self.num_inference_steps = num_inference_steps 217 self.num_inference_steps = num_inference_steps
218 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 218 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
219 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 219 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
220 self.timesteps = self.sigmas 220 self.timesteps = self.sigmas
221 221
222 def add_noise_to_input( 222 def add_noise_to_input(
@@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
272 """ 272 """
273 latents = sample 273 latents = sample
274 sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) 274 sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
275
276 # if callback is not None:
277 # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
278 d = to_d(latents, timestep, model_output) 275 d = to_d(latents, timestep, model_output)
279 # Euler method
280 dt = sigma_down - timestep 276 dt = sigma_down - timestep
281 latents = latents + d * dt 277 latents = latents + d * dt
282 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, 278 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,