| author | Volpeon <git@volpeon.ink> | 2022-10-02 15:14:29 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-02 15:14:29 +0200 |
| commit | 13b0d9f763269df405d1aeba86213f1c7ce4e7ca | |
| tree | b4b2761032e2ba715dac0cf50adee9ff911d73f6 /schedulers | |
| parent | WIP: img2img | |
More consistent euler_a
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 59 |
1 file changed, 26 insertions, 33 deletions
```diff
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 9fbedaa..1b1c9cf 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -1,7 +1,3 @@
-
-
-import math
-import warnings
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        clip_sample: bool = True,
-        set_alpha_to_one: bool = True,
-        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
-        # At every step in ddim, we are looking into the previous alphas_cumprod
-        # For the final step, there is no previous alphas_cumprod because we are already at 0
-        # `set_alpha_to_one` decides whether we set this parameter simply to one or
-        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-
         # setable values
         self.num_inference_steps = None
         self.timesteps = np.arange(0, num_train_timesteps)[::-1]
@@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             the number of diffusion steps used when generating samples with a pre-trained model.
         """
 
-        # offset = self.config.steps_offset
-
-        # if "offset" in kwargs:
-        #     warnings.warn(
-        #         "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
-        #         " Please pass `steps_offset` to `__init__` instead.",
-        #         DeprecationWarning,
-        #     )
-
-        #     offset = kwargs["offset"]
-
         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
-        self.timesteps = self.sigmas
+        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.timesteps = np.arange(0, self.num_inference_steps)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
```
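The point of this hunk is that `set_timesteps` now treats timesteps as plain integer indices into `self.sigmas`, rather than using the sigma values themselves as timesteps. `get_sigmas` is defined earlier in the file and not shown here; the sketch below is an assumption about its behavior, modeled on the `DiscreteSchedule.get_sigmas` convention in k-diffusion (the `append_zero` helper and the linspace-based selection are part of that assumption, not taken from this diff):

```python
import torch

def append_zero(x):
    # k-diffusion convention: a sampling schedule ends at sigma = 0
    return torch.cat([x, x.new_zeros([1])])

def get_sigmas(ds_sigmas, n):
    # Select n noise levels from the full training schedule, ordered from
    # noisiest to cleanest, then append the terminal zero. Index i of the
    # result is what step(..., timestep=i, timestep_prev=i + 1, ...) reads.
    t = torch.linspace(len(ds_sigmas) - 1, 0, n, device=ds_sigmas.device)
    return append_zero(ds_sigmas[t.long()])
```

Under that convention, passing `self.num_inference_steps` instead of `self.num_inference_steps - 1` yields `num_inference_steps + 1` sigmas, so `timestep_prev = timestep + 1` stays in range all the way down to the final, zero-noise entry.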
```diff
@@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: torch.IntTensor,
-        timestep_prev: torch.IntTensor,
+        timestep: int,
+        timestep_prev: int,
         sample: torch.FloatTensor,
         generator: None,
         # ,sigma_hat: float,
@@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
 
         """
+        s = self.sigmas[timestep]
+        s_prev = self.sigmas[timestep_prev]
         latents = sample
-        sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
-        d = to_d(latents, timestep, model_output)
-        dt = sigma_down - timestep
+
+        sigma_down, sigma_up = get_ancestral_step(s, s_prev)
+        d = to_d(latents, s, model_output)
+        dt = sigma_down - s
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
                                         generator=generator) * sigma_up
+
         return SchedulerOutput(prev_sample=latents)
 
     def step_correct(
```
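The substance of the fix: `step` previously passed its `timestep` arguments, which were sigma values, straight into `get_ancestral_step` and `to_d`; now the timesteps are integers and the sigmas are looked up explicitly (`s = self.sigmas[timestep]`). Both helpers are defined elsewhere in this file; the sketches below follow the k-diffusion definitions this scheduler is ported from, simplified to a scalar sigma:

```python
def get_ancestral_step(sigma_from, sigma_to):
    # Split the move from sigma_from down to sigma_to into a deterministic
    # part (sigma_down) and a stochastic part (sigma_up), chosen so that
    # sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2.
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

def to_d(x, sigma, denoised):
    # Derivative dx/dsigma of the probability flow ODE, recovered from the
    # model's denoised prediction of x.
    return (x - denoised) / sigma
```

An Euler step of size `sigma_down - s` plus fresh noise scaled by `sigma_up` is exactly the `latents + d * dt` and `torch.randn(...) * sigma_up` lines in the hunk above.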
```diff
@@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=sample_prev)
 
-    def add_noise(self, original_samples, noise, timesteps):
-        raise NotImplementedError()
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        sigmas = self.sigmas.to(original_samples.device)
+        timesteps = timesteps.to(original_samples.device)
+
+        sigma = sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
```
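`add_noise` goes from a `NotImplementedError` stub to a real implementation that supports a different noise level per sample, broadcasting the gathered sigmas across the sample dimensions. A self-contained rehearsal of that broadcasting, with made-up sigma values and shapes:

```python
import torch

sigmas = torch.linspace(15.0, 0.0, 11)   # stand-in for scheduler.sigmas
original = torch.randn(4, 4, 64, 64)     # batch of 4 latents
noise = torch.randn_like(original)
timesteps = torch.tensor([0, 3, 7, 9])   # one schedule index per sample

sigma = sigmas[timesteps].flatten()      # shape (4,)
while len(sigma.shape) < len(original.shape):
    sigma = sigma.unsqueeze(-1)          # ends up as shape (4, 1, 1, 1)

noisy = original + noise * sigma         # each sample gets its own noise level
```

Consistent with the rest of the commit, the `timesteps` passed to `add_noise` are indices into `self.sigmas`, not training-schedule timesteps.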
