diff options
Diffstat (limited to 'schedulers')
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 323 |
1 files changed, 323 insertions, 0 deletions
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py new file mode 100644 index 0000000..57a56de --- /dev/null +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -0,0 +1,323 @@ | |||
| 1 | |||
| 2 | |||
| 3 | import math | ||
| 4 | import warnings | ||
| 5 | from typing import Optional, Tuple, Union | ||
| 6 | |||
| 7 | import numpy as np | ||
| 8 | import torch | ||
| 9 | |||
| 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
| 11 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
| 12 | |||
| 13 | |||
| 14 | ''' | ||
| 15 | helper functions: append_zero(), | ||
| 16 | t_to_sigma(), | ||
| 17 | get_sigmas(), | ||
| 18 | append_dims(), | ||
| 19 | CFGDenoiserForward(), | ||
| 20 | get_scalings(), | ||
| 21 | DSsigma_to_t(), | ||
| 22 | DiscreteEpsDDPMDenoiserForward(), | ||
| 23 | to_d(), | ||
| 24 | get_ancestral_step() | ||
| 25 | need cleaning | ||
| 26 | ''' | ||
| 27 | |||
| 28 | |||
| 29 | def append_zero(x): | ||
| 30 | return torch.cat([x, x.new_zeros([1])]) | ||
| 31 | |||
| 32 | |||
| 33 | def t_to_sigma(t, sigmas): | ||
| 34 | t = t.float() | ||
| 35 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() | ||
| 36 | return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] | ||
| 37 | |||
| 38 | |||
| 39 | def get_sigmas(sigmas, n=None): | ||
| 40 | if n is None: | ||
| 41 | return append_zero(sigmas.flip(0)) | ||
| 42 | t_max = len(sigmas) - 1 # = 999 | ||
| 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
| 44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
| 45 | return append_zero(t_to_sigma(t, sigmas)) | ||
| 46 | |||
| 47 | # from k_samplers utils.py | ||
| 48 | |||
| 49 | |||
| 50 | def append_dims(x, target_dims): | ||
| 51 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | ||
| 52 | dims_to_append = target_dims - x.ndim | ||
| 53 | if dims_to_append < 0: | ||
| 54 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') | ||
| 55 | return x[(...,) + (None,) * dims_to_append] | ||
| 56 | |||
| 57 | |||
| 58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | ||
| 59 | # x_in = torch.cat([x] * 2)#A# concat the latent | ||
| 60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | ||
| 61 | # cond_in = torch.cat([uncond, cond]) | ||
| 62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | ||
| 63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | ||
| 64 | # return uncond + (cond - uncond) * cond_scale | ||
| 65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | ||
| 66 | return noise_pred | ||
| 67 | |||
| 68 | # from k_samplers sampling.py | ||
| 69 | |||
| 70 | |||
| 71 | def to_d(x, sigma, denoised): | ||
| 72 | """Converts a denoiser output to a Karras ODE derivative.""" | ||
| 73 | return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) | ||
| 74 | |||
| 75 | |||
| 76 | def get_scalings(sigma): | ||
| 77 | sigma_data = 1. | ||
| 78 | c_out = -sigma | ||
| 79 | c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 | ||
| 80 | return c_out, c_in | ||
| 81 | |||
| 82 | # DiscreteSchedule DS | ||
| 83 | |||
| 84 | |||
| 85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | ||
| 86 | # quantize = self.quantize if quantize is None else quantize | ||
| 87 | quantize = False | ||
| 88 | dists = torch.abs(sigma - DSsigmas[:, None]) | ||
| 89 | if quantize: | ||
| 90 | return torch.argmin(dists, dim=0).view(sigma.shape) | ||
| 91 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] | ||
| 92 | low, high = DSsigmas[low_idx], DSsigmas[high_idx] | ||
| 93 | w = (low - sigma) / (low - high) | ||
| 94 | w = w.clamp(0, 1) | ||
| 95 | t = (1 - w) * low_idx + w * high_idx | ||
| 96 | return t.view(sigma.shape) | ||
| 97 | |||
| 98 | |||
| 99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | ||
| 100 | sigma = sigma.to(Unet.device) | ||
| 101 | DSsigmas = DSsigmas.to(Unet.device) | ||
| 102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | ||
| 103 | # ??? what is eps? | ||
| 104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
| 105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
| 106 | encoder_hidden_states=kwargs['cond']).sample | ||
| 107 | return input + eps * c_out | ||
| 108 | |||
| 109 | |||
| 110 | # from k_samplers sampling.py | ||
| 111 | def get_ancestral_step(sigma_from, sigma_to): | ||
| 112 | """Calculates the noise level (sigma_down) to step down to and the amount | ||
| 113 | of noise to add (sigma_up) when doing an ancestral sampling step.""" | ||
| 114 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
| 115 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
| 116 | return sigma_down, sigma_up | ||
| 117 | |||
| 118 | |||
| 119 | ''' | ||
| 120 | Euler Ancestral Scheduler | ||
| 121 | ''' | ||
| 122 | |||
| 123 | |||
| 124 | class EulerAScheduler(SchedulerMixin, ConfigMixin): | ||
| 125 | """ | ||
| 126 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and | ||
| 127 | the VE column of Table 1 from [1] for reference. | ||
| 128 | |||
| 129 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." | ||
| 130 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic | ||
| 131 | differential equations." https://arxiv.org/abs/2011.13456 | ||
| 132 | |||
| 133 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
| 134 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
| 135 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
| 136 | [`~ConfigMixin.from_config`] functions. | ||
| 137 | |||
| 138 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of | ||
| 139 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the | ||
| 140 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. | ||
| 141 | |||
| 142 | Args: | ||
| 143 | sigma_min (`float`): minimum noise magnitude | ||
| 144 | sigma_max (`float`): maximum noise magnitude | ||
| 145 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. | ||
| 146 | A reasonable range is [1.000, 1.011]. | ||
| 147 | s_churn (`float`): the parameter controlling the overall amount of stochasticity. | ||
| 148 | A reasonable range is [0, 100]. | ||
| 149 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). | ||
| 150 | A reasonable range is [0, 10]. | ||
| 151 | s_max (`float`): the end value of the sigma range where we add noise. | ||
| 152 | A reasonable range is [0.2, 80]. | ||
| 153 | |||
| 154 | """ | ||
| 155 | |||
| 156 | @register_to_config | ||
| 157 | def __init__( | ||
| 158 | self, | ||
| 159 | num_train_timesteps: int = 1000, | ||
| 160 | beta_start: float = 0.0001, | ||
| 161 | beta_end: float = 0.02, | ||
| 162 | beta_schedule: str = "linear", | ||
| 163 | trained_betas: Optional[np.ndarray] = None, | ||
| 164 | clip_sample: bool = True, | ||
| 165 | set_alpha_to_one: bool = True, | ||
| 166 | steps_offset: int = 0, | ||
| 167 | ): | ||
| 168 | if trained_betas is not None: | ||
| 169 | self.betas = torch.from_numpy(trained_betas) | ||
| 170 | if beta_schedule == "linear": | ||
| 171 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
| 172 | elif beta_schedule == "scaled_linear": | ||
| 173 | # this schedule is very specific to the latent diffusion model. | ||
| 174 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
| 175 | elif beta_schedule == "squaredcos_cap_v2": | ||
| 176 | # Glide cosine schedule | ||
| 177 | self.betas = betas_for_alpha_bar(num_train_timesteps) | ||
| 178 | else: | ||
| 179 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
| 180 | |||
| 181 | self.alphas = 1.0 - self.betas | ||
| 182 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
| 183 | |||
| 184 | # At every step in ddim, we are looking into the previous alphas_cumprod | ||
| 185 | # For the final step, there is no previous alphas_cumprod because we are already at 0 | ||
| 186 | # `set_alpha_to_one` decides whether we set this parameter simply to one or | ||
| 187 | # whether we use the final alpha of the "non-previous" one. | ||
| 188 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | ||
| 189 | |||
| 190 | # setable values | ||
| 191 | self.num_inference_steps = None | ||
| 192 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | ||
| 193 | |||
| 194 | # A# take number of steps as input | ||
| 195 | # A# store 1) number of steps 2) timesteps 3) schedule | ||
| 196 | |||
| 197 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): | ||
| 198 | """ | ||
| 199 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
| 200 | |||
| 201 | Args: | ||
| 202 | num_inference_steps (`int`): | ||
| 203 | the number of diffusion steps used when generating samples with a pre-trained model. | ||
| 204 | """ | ||
| 205 | |||
| 206 | # offset = self.config.steps_offset | ||
| 207 | |||
| 208 | # if "offset" in kwargs: | ||
| 209 | # warnings.warn( | ||
| 210 | # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
| 211 | # " Please pass `steps_offset` to `__init__` instead.", | ||
| 212 | # DeprecationWarning, | ||
| 213 | # ) | ||
| 214 | |||
| 215 | # offset = kwargs["offset"] | ||
| 216 | |||
| 217 | self.num_inference_steps = num_inference_steps | ||
| 218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | ||
| 219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | ||
| 220 | self.timesteps = self.sigmas | ||
| 221 | |||
| 222 | def add_noise_to_input( | ||
| 223 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | ||
| 224 | ) -> Tuple[torch.FloatTensor, float]: | ||
| 225 | """ | ||
| 226 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a | ||
| 227 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. | ||
| 228 | |||
| 229 | TODO Args: | ||
| 230 | """ | ||
| 231 | if self.config.s_min <= sigma <= self.config.s_max: | ||
| 232 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) | ||
| 233 | else: | ||
| 234 | gamma = 0 | ||
| 235 | |||
| 236 | # sample eps ~ N(0, S_noise^2 * I) | ||
| 237 | eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) | ||
| 238 | sigma_hat = sigma + gamma * sigma | ||
| 239 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) | ||
| 240 | |||
| 241 | return sample_hat, sigma_hat | ||
| 242 | |||
| 243 | def step( | ||
| 244 | self, | ||
| 245 | model_output: torch.FloatTensor, | ||
| 246 | timestep: torch.IntTensor, | ||
| 247 | timestep_prev: torch.IntTensor, | ||
| 248 | sample: torch.FloatTensor, | ||
| 249 | generator: None, | ||
| 250 | # ,sigma_hat: float, | ||
| 251 | # sigma_prev: float, | ||
| 252 | # sample_hat: torch.FloatTensor, | ||
| 253 | return_dict: bool = True, | ||
| 254 | ) -> Union[SchedulerOutput, Tuple]: | ||
| 255 | """ | ||
| 256 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
| 257 | process from the learned model outputs (most often the predicted noise). | ||
| 258 | |||
| 259 | Args: | ||
| 260 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
| 261 | sigma_hat (`float`): TODO | ||
| 262 | sigma_prev (`float`): TODO | ||
| 263 | sample_hat (`torch.FloatTensor`): TODO | ||
| 264 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
| 265 | |||
| 266 | EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). | ||
| 267 | Returns: | ||
| 268 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: | ||
| 269 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
| 270 | returning a tuple, the first element is the sample tensor. | ||
| 271 | |||
| 272 | """ | ||
| 273 | latents = sample | ||
| 274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | ||
| 275 | |||
| 276 | # if callback is not None: | ||
| 277 | # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output}) | ||
| 278 | d = to_d(latents, timestep, model_output) | ||
| 279 | # Euler method | ||
| 280 | dt = sigma_down - timestep | ||
| 281 | latents = latents + d * dt | ||
| 282 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | ||
| 283 | generator=generator) * sigma_up | ||
| 284 | return SchedulerOutput(prev_sample=latents) | ||
| 285 | |||
| 286 | def step_correct( | ||
| 287 | self, | ||
| 288 | model_output: torch.FloatTensor, | ||
| 289 | sigma_hat: float, | ||
| 290 | sigma_prev: float, | ||
| 291 | sample_hat: torch.FloatTensor, | ||
| 292 | sample_prev: torch.FloatTensor, | ||
| 293 | derivative: torch.FloatTensor, | ||
| 294 | generator: None, | ||
| 295 | return_dict: bool = True, | ||
| 296 | ) -> Union[SchedulerOutput, Tuple]: | ||
| 297 | """ | ||
| 298 | Correct the predicted sample based on the output model_output of the network. TODO complete description | ||
| 299 | |||
| 300 | Args: | ||
| 301 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
| 302 | sigma_hat (`float`): TODO | ||
| 303 | sigma_prev (`float`): TODO | ||
| 304 | sample_hat (`torch.FloatTensor`): TODO | ||
| 305 | sample_prev (`torch.FloatTensor`): TODO | ||
| 306 | derivative (`torch.FloatTensor`): TODO | ||
| 307 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
| 308 | |||
| 309 | Returns: | ||
| 310 | prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO | ||
| 311 | |||
| 312 | """ | ||
| 313 | pred_original_sample = sample_prev + sigma_prev * model_output | ||
| 314 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev | ||
| 315 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) | ||
| 316 | |||
| 317 | if not return_dict: | ||
| 318 | return (sample_prev, derivative) | ||
| 319 | |||
| 320 | return SchedulerOutput(prev_sample=sample_prev) | ||
| 321 | |||
| 322 | def add_noise(self, original_samples, noise, timesteps): | ||
| 323 | raise NotImplementedError() | ||
