diff options
author | Volpeon <git@volpeon.ink> | 2022-10-01 11:40:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-01 11:40:14 +0200 |
commit | 5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 (patch) | |
tree | a3461a4f1a04fba52ec8fde8b7b07095c7422d85 /schedulers | |
parent | Added custom SD pipeline + euler_a scheduler (diff) | |
download | textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.gz textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.bz2 textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.zip |
Made inference script interactive
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 57a56de..29ebd07 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
216 | 216 | ||
217 | self.num_inference_steps = num_inference_steps | 217 | self.num_inference_steps = num_inference_steps |
218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) |
220 | self.timesteps = self.sigmas | 220 | self.timesteps = self.sigmas |
221 | 221 | ||
222 | def add_noise_to_input( | 222 | def add_noise_to_input( |
@@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
272 | """ | 272 | """ |
273 | latents = sample | 273 | latents = sample |
274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | 274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) |
275 | |||
276 | # if callback is not None: | ||
277 | # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output}) | ||
278 | d = to_d(latents, timestep, model_output) | 275 | d = to_d(latents, timestep, model_output) |
279 | # Euler method | ||
280 | dt = sigma_down - timestep | 276 | dt = sigma_down - timestep |
281 | latents = latents + d * dt | 277 | latents = latents + d * dt |
282 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | 278 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, |