diff options
Diffstat (limited to 'schedulers')
-rw-r--r-- | schedulers/scheduling_euler_a.py | 59 |
1 files changed, 26 insertions, 33 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 9fbedaa..1b1c9cf 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -1,7 +1,3 @@ | |||
1 | |||
2 | |||
3 | import math | ||
4 | import warnings | ||
5 | from typing import Optional, Tuple, Union | 1 | from typing import Optional, Tuple, Union |
6 | 2 | ||
7 | import numpy as np | 3 | import numpy as np |
@@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
157 | beta_end: float = 0.02, | 153 | beta_end: float = 0.02, |
158 | beta_schedule: str = "linear", | 154 | beta_schedule: str = "linear", |
159 | trained_betas: Optional[np.ndarray] = None, | 155 | trained_betas: Optional[np.ndarray] = None, |
160 | clip_sample: bool = True, | ||
161 | set_alpha_to_one: bool = True, | ||
162 | steps_offset: int = 0, | ||
163 | ): | 156 | ): |
164 | if trained_betas is not None: | 157 | if trained_betas is not None: |
165 | self.betas = torch.from_numpy(trained_betas) | 158 | self.betas = torch.from_numpy(trained_betas) |
@@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
177 | self.alphas = 1.0 - self.betas | 170 | self.alphas = 1.0 - self.betas |
178 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 171 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
179 | 172 | ||
180 | # At every step in ddim, we are looking into the previous alphas_cumprod | ||
181 | # For the final step, there is no previous alphas_cumprod because we are already at 0 | ||
182 | # `set_alpha_to_one` decides whether we set this parameter simply to one or | ||
183 | # whether we use the final alpha of the "non-previous" one. | ||
184 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | ||
185 | |||
186 | # setable values | 173 | # setable values |
187 | self.num_inference_steps = None | 174 | self.num_inference_steps = None |
188 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | 175 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] |
@@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
199 | the number of diffusion steps used when generating samples with a pre-trained model. | 186 | the number of diffusion steps used when generating samples with a pre-trained model. |
200 | """ | 187 | """ |
201 | 188 | ||
202 | # offset = self.config.steps_offset | ||
203 | |||
204 | # if "offset" in kwargs: | ||
205 | # warnings.warn( | ||
206 | # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
207 | # " Please pass `steps_offset` to `__init__` instead.", | ||
208 | # DeprecationWarning, | ||
209 | # ) | ||
210 | |||
211 | # offset = kwargs["offset"] | ||
212 | |||
213 | self.num_inference_steps = num_inference_steps | 189 | self.num_inference_steps = num_inference_steps |
214 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 190 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
215 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) | 191 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) |
216 | self.timesteps = self.sigmas | 192 | self.timesteps = np.arange(0, self.num_inference_steps) |
217 | 193 | ||
218 | def add_noise_to_input( | 194 | def add_noise_to_input( |
219 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | 195 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None |
@@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
239 | def step( | 215 | def step( |
240 | self, | 216 | self, |
241 | model_output: torch.FloatTensor, | 217 | model_output: torch.FloatTensor, |
242 | timestep: torch.IntTensor, | 218 | timestep: int, |
243 | timestep_prev: torch.IntTensor, | 219 | timestep_prev: int, |
244 | sample: torch.FloatTensor, | 220 | sample: torch.FloatTensor, |
245 | generator: None, | 221 | generator: None, |
246 | # ,sigma_hat: float, | 222 | # ,sigma_hat: float, |
@@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
266 | returning a tuple, the first element is the sample tensor. | 242 | returning a tuple, the first element is the sample tensor. |
267 | 243 | ||
268 | """ | 244 | """ |
245 | s = self.sigmas[timestep] | ||
246 | s_prev = self.sigmas[timestep_prev] | ||
269 | latents = sample | 247 | latents = sample |
270 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | 248 | |
271 | d = to_d(latents, timestep, model_output) | 249 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) |
272 | dt = sigma_down - timestep | 250 | d = to_d(latents, s, model_output) |
251 | dt = sigma_down - s | ||
273 | latents = latents + d * dt | 252 | latents = latents + d * dt |
274 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | 253 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, |
275 | generator=generator) * sigma_up | 254 | generator=generator) * sigma_up |
255 | |||
276 | return SchedulerOutput(prev_sample=latents) | 256 | return SchedulerOutput(prev_sample=latents) |
277 | 257 | ||
278 | def step_correct( | 258 | def step_correct( |
@@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
311 | 291 | ||
312 | return SchedulerOutput(prev_sample=sample_prev) | 292 | return SchedulerOutput(prev_sample=sample_prev) |
313 | 293 | ||
314 | def add_noise(self, original_samples, noise, timesteps): | 294 | def add_noise( |
315 | raise NotImplementedError() | 295 | self, |
296 | original_samples: torch.FloatTensor, | ||
297 | noise: torch.FloatTensor, | ||
298 | timesteps: torch.IntTensor, | ||
299 | ) -> torch.FloatTensor: | ||
300 | sigmas = self.sigmas.to(original_samples.device) | ||
301 | timesteps = timesteps.to(original_samples.device) | ||
302 | |||
303 | sigma = sigmas[timesteps].flatten() | ||
304 | while len(sigma.shape) < len(original_samples.shape): | ||
305 | sigma = sigma.unsqueeze(-1) | ||
306 | |||
307 | noisy_samples = original_samples + noise * sigma | ||
308 | return noisy_samples | ||