summary refs log tree commit diff stats
path: root/schedulers
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-01 11:40:14 +0200
committerVolpeon <git@volpeon.ink>2022-10-01 11:40:14 +0200
commit5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 (patch)
treea3461a4f1a04fba52ec8fde8b7b07095c7422d85 /schedulers
parentAdded custom SD pipeline + euler_a scheduler (diff)
downloadtextual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.gz
textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.tar.bz2
textual-inversion-diff-5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688.zip
Made inference script interactive
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py6
1 file changed, 1 insertion, 5 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 57a56de..29ebd07 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
216 216
217 self.num_inference_steps = num_inference_steps 217 self.num_inference_steps = num_inference_steps
218 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 218 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
219 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 219 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
220 self.timesteps = self.sigmas 220 self.timesteps = self.sigmas
221 221
222 def add_noise_to_input( 222 def add_noise_to_input(
@@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
272 """ 272 """
273 latents = sample 273 latents = sample
274 sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) 274 sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
275
276 # if callback is not None:
277 # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
278 d = to_d(latents, timestep, model_output) 275 d = to_d(latents, timestep, model_output)
279 # Euler method
280 dt = sigma_down - timestep 276 dt = sigma_down - timestep
281 latents = latents + d * dt 277 latents = latents + d * dt
282 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, 278 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,