from typing import Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput class EulerAScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`~ConfigMixin.from_config`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. Args: sigma_min (`float`): minimum noise magnitude sigma_max (`float`): maximum noise magnitude s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, 1.011]. s_churn (`float`): the parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100]. s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). A reasonable range is [0, 10]. s_max (`float`): the end value of the sigma range where we add noise. A reasonable range is [0.2, 80]. """ @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, tensor_format: str = "pt", num_inference_steps=None, device='cuda' ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas).to(device) if beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, device=device) ** 2 else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.device = device self.tensor_format = tensor_format self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() # get sigmas self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) self.set_format(tensor_format=tensor_format) # A# take number of steps as input # A# store 1) number of steps 2) timesteps 3) schedule def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) self.timesteps = self.sigmas[:-1] self.is_scale_input_called = False def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) if self.is_scale_input_called: return sample step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] sample = sample * sigma self.is_scale_input_called = True return sample def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, generator: torch.Generator = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. sigma_hat (`float`): TODO sigma_prev (`float`): TODO sample_hat (`torch.FloatTensor`): TODO return_dict (`bool`): option for returning tuple rather than SchedulerOutput class EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). Returns: [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() step_prev_index = step_index + 1 s = self.sigmas[step_index] s_prev = self.sigmas[step_prev_index] latents = sample sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) d = self.to_d(latents, s, model_output) dt = sigma_down - s latents = latents + d * dt latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, generator=generator) * sigma_up return SchedulerOutput(prev_sample=latents) def step_correct( self, model_output: torch.FloatTensor, sigma_hat: float, sigma_prev: float, sample_hat: torch.FloatTensor, sample_prev: torch.FloatTensor, derivative: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. TODO complete description Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. sigma_hat (`float`): TODO sigma_prev (`float`): TODO sample_hat (`torch.FloatTensor`): TODO sample_prev (`torch.FloatTensor`): TODO derivative (`torch.FloatTensor`): TODO return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO """ pred_original_sample = sample_prev + sigma_prev * model_output derivative_corr = (sample_prev - pred_original_sample) / sigma_prev sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) if not return_dict: return (sample_prev, derivative) return SchedulerOutput(prev_sample=sample_prev) def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: sigmas = self.sigmas.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma self.is_scale_input_called = True return noisy_samples # from k_samplers sampling.py def get_ancestral_step(self, sigma_from, sigma_to): """Calculates the noise level (sigma_down) to step down to and the amount of noise to add (sigma_up) when doing an ancestral sampling step.""" sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 return sigma_down, sigma_up def t_to_sigma(self, t, sigmas): t = t.float() low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] def append_zero(self, x): return torch.cat([x, x.new_zeros([1])]) def get_sigmas(self, sigmas, n=None): if n is None: return self.append_zero(sigmas.flip(0)) t_max = len(sigmas) - 1 # = 999 device = self.device t = torch.linspace(t_max, 0, n, device=device) # t = torch.linspace(t_max, 0, n, device=sigmas.device) return self.append_zero(self.t_to_sigma(t, sigmas)) # from k_samplers utils.py def append_dims(self, x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') return x[(...,) + (None,) * dims_to_append] # from k_samplers sampling.py def to_d(self, x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / self.append_dims(sigma, x.ndim) def get_scalings(self, sigma): sigma_data = 1. c_out = -sigma c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 return c_out, c_in # DiscreteSchedule DS def DSsigma_to_t(self, sigma, quantize=None): # quantize = self.quantize if quantize is None else quantize quantize = False dists = torch.abs(sigma - self.DSsigmas[:, None]) if quantize: return torch.argmin(dists, dim=0).view(sigma.shape) low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] w = (low - sigma) / (low - high) w = w.clamp(0, 1) t = (1 - w) * low_idx + w * high_idx return t.view(sigma.shape) def prepare_input(self, latent_in, t, batch_size): sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 sigma_in = torch.cat([sigma] * 2 * batch_size) # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] sigma_in = self.DSsigma_to_t(sigma_in) # s_in = latent_in.new_ones([latent_in.shape[0]]) # sigma_in = sigma_in * s_in return c_out, c_in, sigma_in