diff options
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 57a56de..29ebd07 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 216 | 216 | ||
| 217 | self.num_inference_steps = num_inference_steps | 217 | self.num_inference_steps = num_inference_steps |
| 218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
| 219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) |
| 220 | self.timesteps = self.sigmas | 220 | self.timesteps = self.sigmas |
| 221 | 221 | ||
| 222 | def add_noise_to_input( | 222 | def add_noise_to_input( |
| @@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 272 | """ | 272 | """ |
| 273 | latents = sample | 273 | latents = sample |
| 274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | 274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) |
| 275 | |||
| 276 | # if callback is not None: | ||
| 277 | # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output}) | ||
| 278 | d = to_d(latents, timestep, model_output) | 275 | d = to_d(latents, timestep, model_output) |
| 279 | # Euler method | ||
| 280 | dt = sigma_down - timestep | 276 | dt = sigma_down - timestep |
| 281 | latents = latents + d * dt | 277 | latents = latents + d * dt |
| 282 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | 278 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, |
