diff options
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 45 |
1 file changed, 38 insertions, 7 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index c6436d8..13ea6b3 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
| @@ -171,6 +171,9 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 171 | self.alphas = 1.0 - self.betas | 171 | self.alphas = 1.0 - self.betas |
| 172 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | 172 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
| 173 | 173 | ||
| 174 | # standard deviation of the initial noise distribution | ||
| 175 | self.init_noise_sigma = 1.0 | ||
| 176 | |||
| 174 | # setable values | 177 | # setable values |
| 175 | self.num_inference_steps = None | 178 | self.num_inference_steps = None |
| 176 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | 179 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] |
| @@ -190,13 +193,33 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 190 | self.num_inference_steps = num_inference_steps | 193 | self.num_inference_steps = num_inference_steps |
| 191 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | 194 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 |
| 192 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | 195 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) |
| 193 | self.timesteps = np.arange(0, self.num_inference_steps) | 196 | self.timesteps = self.sigmas[:-1] |
| 197 | self.is_scale_input_called = False | ||
| 198 | |||
| 199 | def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: | ||
| 200 | """ | ||
| 201 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the | ||
| 202 | current timestep. | ||
| 203 | Args: | ||
| 204 | sample (`torch.FloatTensor`): input sample | ||
| 205 | timestep (`int`, optional): current timestep | ||
| 206 | Returns: | ||
| 207 | `torch.FloatTensor`: scaled input sample | ||
| 208 | """ | ||
| 209 | if isinstance(timestep, torch.Tensor): | ||
| 210 | timestep = timestep.to(self.timesteps.device) | ||
| 211 | if self.is_scale_input_called: | ||
| 212 | return sample | ||
| 213 | step_index = (self.timesteps == timestep).nonzero().item() | ||
| 214 | sigma = self.sigmas[step_index] | ||
| 215 | sample = sample * sigma | ||
| 216 | self.is_scale_input_called = True | ||
| 217 | return sample | ||
| 194 | 218 | ||
| 195 | def step( | 219 | def step( |
| 196 | self, | 220 | self, |
| 197 | model_output: torch.FloatTensor, | 221 | model_output: torch.FloatTensor, |
| 198 | timestep: int, | 222 | timestep: Union[float, torch.FloatTensor], |
| 199 | timestep_prev: int, | ||
| 200 | sample: torch.FloatTensor, | 223 | sample: torch.FloatTensor, |
| 201 | generator: torch.Generator = None, | 224 | generator: torch.Generator = None, |
| 202 | return_dict: bool = True, | 225 | return_dict: bool = True, |
| @@ -219,8 +242,13 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 219 | returning a tuple, the first element is the sample tensor. | 242 | returning a tuple, the first element is the sample tensor. |
| 220 | 243 | ||
| 221 | """ | 244 | """ |
| 222 | s = self.sigmas[timestep] | 245 | if isinstance(timestep, torch.Tensor): |
| 223 | s_prev = self.sigmas[timestep_prev] | 246 | timestep = timestep.to(self.timesteps.device) |
| 247 | step_index = (self.timesteps == timestep).nonzero().item() | ||
| 248 | step_prev_index = step_index + 1 | ||
| 249 | |||
| 250 | s = self.sigmas[step_index] | ||
| 251 | s_prev = self.sigmas[step_prev_index] | ||
| 224 | latents = sample | 252 | latents = sample |
| 225 | 253 | ||
| 226 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) | 254 | sigma_down, sigma_up = get_ancestral_step(s, s_prev) |
| @@ -271,14 +299,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 271 | self, | 299 | self, |
| 272 | original_samples: torch.FloatTensor, | 300 | original_samples: torch.FloatTensor, |
| 273 | noise: torch.FloatTensor, | 301 | noise: torch.FloatTensor, |
| 274 | timesteps: torch.IntTensor, | 302 | timesteps: torch.FloatTensor, |
| 275 | ) -> torch.FloatTensor: | 303 | ) -> torch.FloatTensor: |
| 276 | sigmas = self.sigmas.to(original_samples.device) | 304 | sigmas = self.sigmas.to(original_samples.device) |
| 305 | schedule_timesteps = self.timesteps.to(original_samples.device) | ||
| 277 | timesteps = timesteps.to(original_samples.device) | 306 | timesteps = timesteps.to(original_samples.device) |
| 307 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
| 278 | 308 | ||
| 279 | sigma = sigmas[timesteps].flatten() | 309 | sigma = sigmas[step_indices].flatten() |
| 280 | while len(sigma.shape) < len(original_samples.shape): | 310 | while len(sigma.shape) < len(original_samples.shape): |
| 281 | sigma = sigma.unsqueeze(-1) | 311 | sigma = sigma.unsqueeze(-1) |
| 282 | 312 | ||
| 283 | noisy_samples = original_samples + noise * sigma | 313 | noisy_samples = original_samples + noise * sigma |
| 314 | self.is_scale_input_called = True | ||
| 284 | return noisy_samples | 315 | return noisy_samples |
