summaryrefslogtreecommitdiffstats
path: root/schedulers
diff options
context:
space:
mode:
Diffstat (limited to 'schedulers')
-rw-r--r--schedulers/scheduling_euler_a.py24
1 files changed, 0 insertions, 24 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 1b1c9cf..a2d0e9f 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -191,27 +191,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
191 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 191 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
192 self.timesteps = np.arange(0, self.num_inference_steps) 192 self.timesteps = np.arange(0, self.num_inference_steps)
193 193
194 def add_noise_to_input(
195 self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
196 ) -> Tuple[torch.FloatTensor, float]:
197 """
198 Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
199 higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
200
201 TODO Args:
202 """
203 if self.config.s_min <= sigma <= self.config.s_max:
204 gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
205 else:
206 gamma = 0
207
208 # sample eps ~ N(0, S_noise^2 * I)
209 eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
210 sigma_hat = sigma + gamma * sigma
211 sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
212
213 return sample_hat, sigma_hat
214
215 def step( 194 def step(
216 self, 195 self,
217 model_output: torch.FloatTensor, 196 model_output: torch.FloatTensor,
@@ -219,9 +198,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
219 timestep_prev: int, 198 timestep_prev: int,
220 sample: torch.FloatTensor, 199 sample: torch.FloatTensor,
221 generator: None, 200 generator: None,
222 # ,sigma_hat: float,
223 # sigma_prev: float,
224 # sample_hat: torch.FloatTensor,
225 return_dict: bool = True, 201 return_dict: bool = True,
226 ) -> Union[SchedulerOutput, Tuple]: 202 ) -> Union[SchedulerOutput, Tuple]:
227 """ 203 """