diff options
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 24 |
1 files changed, 0 insertions, 24 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 1b1c9cf..a2d0e9f 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -191,27 +191,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
191 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 191 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) |
192 | self.timesteps = np.arange(0, self.num_inference_steps) | 192 | self.timesteps = np.arange(0, self.num_inference_steps) |
193 | 193 | ||
194 | def add_noise_to_input( | ||
195 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | ||
196 | ) -> Tuple[torch.FloatTensor, float]: | ||
197 | """ | ||
198 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a | ||
199 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. | ||
200 | |||
201 | TODO Args: | ||
202 | """ | ||
203 | if self.config.s_min <= sigma <= self.config.s_max: | ||
204 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) | ||
205 | else: | ||
206 | gamma = 0 | ||
207 | |||
208 | # sample eps ~ N(0, S_noise^2 * I) | ||
209 | eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) | ||
210 | sigma_hat = sigma + gamma * sigma | ||
211 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) | ||
212 | |||
213 | return sample_hat, sigma_hat | ||
214 | |||
215 | def step( | 194 | def step( |
216 | self, | 195 | self, |
217 | model_output: torch.FloatTensor, | 196 | model_output: torch.FloatTensor, |
@@ -219,9 +198,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
219 | timestep_prev: int, | 198 | timestep_prev: int, |
220 | sample: torch.FloatTensor, | 199 | sample: torch.FloatTensor, |
221 | generator: None, | 200 | generator: None, |
222 | # ,sigma_hat: float, | ||
223 | # sigma_prev: float, | ||
224 | # sample_hat: torch.FloatTensor, | ||
225 | return_dict: bool = True, | 201 | return_dict: bool = True, |
226 | ) -> Union[SchedulerOutput, Tuple]: | 202 | ) -> Union[SchedulerOutput, Tuple]: |
227 | """ | 203 | """ |