From 49463992f48ec25f2ea31b220a6cedac3466467a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 26 Oct 2022 11:11:33 +0200 Subject: New Euler_a scheduler --- schedulers/scheduling_euler_a.py | 286 ---------------------- schedulers/scheduling_euler_ancestral_discrete.py | 192 +++++++++++++++ 2 files changed, 192 insertions(+), 286 deletions(-) delete mode 100644 schedulers/scheduling_euler_a.py create mode 100644 schedulers/scheduling_euler_ancestral_discrete.py (limited to 'schedulers') diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py deleted file mode 100644 index c097a8a..0000000 --- a/schedulers/scheduling_euler_a.py +++ /dev/null @@ -1,286 +0,0 @@ -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput - - -class EulerAScheduler(SchedulerMixin, ConfigMixin): - """ - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. - - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - - Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. - A reasonable range is [0.2, 80]. - - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - num_inference_steps=None, - device='cuda' - ): - if trained_betas is not None: - self.betas = torch.from_numpy(trained_betas).to(device) - if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, - dtype=torch.float32, device=device) ** 2 - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.device = device - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = num_inference_steps - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - # get sigmas - self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) - - # A# take number of steps as input - # A# store 1) number of steps 2) timesteps 3) schedule - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - - self.num_inference_steps = num_inference_steps - self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) - self.timesteps = self.sigmas[:-1] - self.is_scale_input_called = False - - def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - Returns: - `torch.FloatTensor`: scaled input sample - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - if self.is_scale_input_called: - return sample - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - sample = sample * sigma - self.is_scale_input_called = True - return sample - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: torch.Generator = None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). - Returns: - [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: - [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - step_prev_index = step_index + 1 - - s = self.sigmas[step_index] - s_prev = self.sigmas[step_prev_index] - latents = sample - - sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) - d = self.to_d(latents, s, model_output) - dt = sigma_down - s - latents = latents + d * dt - latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, - generator=generator) * sigma_up - - return SchedulerOutput(prev_sample=latents) - - def step_correct( - self, - model_output: torch.FloatTensor, - sigma_hat: float, - sigma_prev: float, - sample_hat: torch.FloatTensor, - sample_prev: torch.FloatTensor, - derivative: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. TODO complete description - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - sample_prev (`torch.FloatTensor`): TODO - derivative (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO - - """ - pred_original_sample = sample_prev + sigma_prev * model_output - derivative_corr = (sample_prev - pred_original_sample) / sigma_prev - sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) - - if not return_dict: - return (sample_prev, derivative) - - return SchedulerOutput(prev_sample=sample_prev) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - sigmas = self.sigmas.to(original_samples.device) - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - self.is_scale_input_called = True - return noisy_samples - - # from k_samplers sampling.py - - def get_ancestral_step(self, sigma_from, sigma_to): - """Calculates the noise level (sigma_down) to step down to and the amount - of noise to add (sigma_up) when doing an ancestral sampling step.""" - sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 - return sigma_down, sigma_up - - def t_to_sigma(self, t, sigmas): - t = t.float() - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() - return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] - - def append_zero(self, x): - return torch.cat([x, x.new_zeros([1])]) - - def get_sigmas(self, sigmas, n=None): - if n is None: - return self.append_zero(sigmas.flip(0)) - t_max = len(sigmas) - 1 # = 999 - device = self.device - t = torch.linspace(t_max, 0, n, device=device) - # t = torch.linspace(t_max, 0, n, device=sigmas.device) - return self.append_zero(self.t_to_sigma(t, sigmas)) - - # from k_samplers utils.py - def append_dims(self, x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - # from k_samplers sampling.py - def to_d(self, x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / self.append_dims(sigma, x.ndim) - - def get_scalings(self, sigma): - sigma_data = 1. - c_out = -sigma - c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 - return c_out, c_in - - # DiscreteSchedule DS - def DSsigma_to_t(self, sigma, quantize=None): - # quantize = self.quantize if quantize is None else quantize - quantize = False - dists = torch.abs(sigma - self.DSsigmas[:, None]) - if quantize: - return torch.argmin(dists, dim=0).view(sigma.shape) - low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] - low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] - w = (low - sigma) / (low - high) - w = w.clamp(0, 1) - t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) - - def prepare_input(self, latent_in, t, batch_size): - sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 - - sigma_in = torch.cat([sigma] * 2 * batch_size) - # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) - # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) - c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] - - sigma_in = self.DSsigma_to_t(sigma_in) - # s_in = latent_in.new_ones([latent_in.shape[0]]) - # sigma_in = sigma_in * s_in - - return c_out, c_in, sigma_in diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 0000000..3a2de68 --- /dev/null +++ b/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,192 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. + for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + self.init_noise_sigma = None + + # setable values + self.num_inference_steps = None + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.derivatives = [] + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] + self.derivatives = [] + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + step_index: Union[int, torch.IntTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps -- cgit v1.2.3-70-g09d2