import math
from typing import Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


# Helper functions: append_zero(), t_to_sigma(), get_sigmas(), append_dims(),
# CFGDenoiserForward(), get_scalings(), DSsigma_to_t(),
# DiscreteEpsDDPMDenoiserForward(), to_d(), get_ancestral_step().
# TODO: these helpers still need cleaning.


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """Create a beta schedule from the squared-cosine cumulative alpha function.

    For each step ``i`` the beta is ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``
    with ``alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2``, capped at ``max_beta``.

    Args:
        num_diffusion_timesteps (`int`): number of betas to produce.
        max_beta (`float`): upper bound applied to every beta.

    Returns:
        `torch.Tensor`: float32 tensor of shape ``(num_diffusion_timesteps,)``.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


def append_zero(x):
    """Return the 1-D tensor ``x`` with a single zero element appended."""
    return torch.cat([x, x.new_zeros([1])])


def t_to_sigma(t, sigmas):
    """Linearly interpolate the ``sigmas`` table at the (possibly fractional) indices ``t``."""
    t = t.float()
    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
    return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]


def get_sigmas(sigmas, n=None):
    """Build the sampling sigma sequence, ordered from the last table entry to the first, ending in 0.

    Args:
        sigmas: 1-D tensor of per-training-timestep sigmas.
        n: number of sampling steps. When ``None`` the full table is used (reversed).

    Returns:
        Tensor of ``n + 1`` values (or ``len(sigmas) + 1`` when ``n`` is ``None``); the last one is 0.
    """
    if n is None:
        return append_zero(sigmas.flip(0))
    t_max = len(sigmas) - 1  # e.g. 999 for a 1000-entry table
    t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
    return append_zero(t_to_sigma(t, sigmas))


# from k_samplers utils.py
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]


def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
    """Run the denoiser on an already-batched input.

    This is a thin wrapper around `DiscreteEpsDDPMDenoiserForward`. It does not build the
    unconditional/conditional batch and does not apply the guidance-weighted combination
    (``uncond + (cond - uncond) * cond_scale``); ``cond_scale`` is accepted for interface
    compatibility but is not used here, so the caller has to do both.

    Args:
        Unet: model passed through to `DiscreteEpsDDPMDenoiserForward`.
        x_in: latent input.
        sigma_in: sigma(s) matching ``x_in``.
        cond_in: conditioning, forwarded as ``cond``.
        cond_scale: unused.
        quantize: forwarded to `DSsigma_to_t`.
        DSsigmas: discrete sigma table, forwarded to `DSsigma_to_t`.

    Returns:
        The denoiser output for ``x_in``.
    """
    noise_pred = DiscreteEpsDDPMDenoiserForward(
        Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
    return noise_pred


# from k_samplers sampling.py
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim)


def get_scalings(sigma):
    """Return the ``(c_out, c_in)`` scalings for ``sigma`` with ``sigma_data`` fixed to 1."""
    sigma_data = 1.
    c_out = -sigma
    c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
    return c_out, c_in


# DiscreteSchedule DS
def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
    """Map sigma value(s) to (possibly fractional) indices into the discrete table ``DSsigmas``.

    Args:
        sigma: tensor of sigma values.
        quantize: when True, return the index of the nearest table entry.
        DSsigmas: 1-D tensor of discrete sigmas. Required despite the ``None`` default.

    Returns:
        Tensor shaped like ``sigma``: integer indices when ``quantize`` is True, otherwise the
        interpolated position between the two nearest table entries.
    """
    dists = torch.abs(sigma - DSsigmas[:, None])
    if quantize:
        return torch.argmin(dists, dim=0).view(sigma.shape)
    # Two nearest table entries, ordered by index so that low_idx < high_idx.
    low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
    low, high = DSsigmas[low_idx], DSsigmas[high_idx]
    w = (low - sigma) / (low - high)
    w = w.clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)


def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
    """Evaluate an eps-predicting UNet at noise level ``sigma`` and return ``input + eps * c_out``.

    Args:
        Unet: model exposing ``.device`` and callable as
            ``Unet(sample, t, encoder_hidden_states=...)`` with a ``.sample`` result.
        input: noisy latent. (The name shadows the builtin but is part of the keyword interface.)
        sigma: sigma(s) for ``input``; cast to ``input.dtype`` and moved to ``Unet.device``.
        DSsigmas: discrete sigma table used to convert ``sigma`` to a timestep.
        quantize: forwarded to `DSsigma_to_t`.
        **kwargs: must contain ``cond`` (passed as ``encoder_hidden_states``).

    Raises:
        KeyError: if ``cond`` is not supplied.
    """
    sigma = sigma.to(dtype=input.dtype, device=Unet.device)
    DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
    c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
    eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
               encoder_hidden_states=kwargs['cond']).sample
    return input + eps * c_out


# from k_samplers sampling.py
def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount of noise to add
    (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


# Euler Ancestral Scheduler
class EulerAScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler ancestral sampler operating on a discrete sigma schedule derived from a beta schedule.

    Each `step` takes an Euler step along the derivative returned by `to_d` down to ``sigma_down`` and then
    adds fresh noise scaled by ``sigma_up`` (see `get_ancestral_step`).

    For background see Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative
    Models." https://arxiv.org/abs/2206.00364

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`]
    and [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule; one of `linear`, `scaled_linear` or `squaredcos_cap_v2`. Ignored when
            `trained_betas` is given.
        trained_betas (`np.ndarray`, optional):
            array of betas to use directly instead of building them from `beta_start`/`beta_end`.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        # Explicit betas take precedence; the chain must be if/elif so they are not
        # overwritten by the (default) named schedule.
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
        # Defined here so the attribute always exists; set_timesteps() resets it.
        self.is_scale_input_called = False

    # Takes the number of steps as input and stores 1) the number of steps,
    # 2) the timesteps and 3) the sigma schedule.
    def set_timesteps(
        self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, **kwargs
    ):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        After this call `self.timesteps` holds the sigma values of the schedule (all entries of `self.sigmas`
        except the trailing zero), not integer training timesteps.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                device the sigma schedule is moved to.
        """
        self.num_inference_steps = num_inference_steps
        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
        self.timesteps = self.sigmas[:-1]
        self.is_scale_input_called = False

    def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        The sample is multiplied by the sigma of `timestep` only on the first call after `set_timesteps`
        (or never, if `add_noise` was called first); later calls return the sample unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep; must be one of the values in `self.timesteps`

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        if self.is_scale_input_called:
            return sample
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        sample = sample * sigma
        self.is_scale_input_called = True
        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep with one Euler ancestral step.

        Args:
            model_output (`torch.FloatTensor`): denoised prediction, as consumed by `to_d`.
            timestep (`float` or `torch.FloatTensor`):
                current entry of `self.timesteps` (a sigma value); used to look up the step index.
            sample (`torch.FloatTensor`): current latents.
            generator (`torch.Generator`, optional): random number generator for the added noise.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] or `tuple`: [`SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero().item()
        step_prev_index = step_index + 1

        s = self.sigmas[step_index]
        s_prev = self.sigmas[step_prev_index]
        latents = sample

        sigma_down, sigma_up = get_ancestral_step(s, s_prev)
        d = to_d(latents, s, model_output)
        dt = sigma_down - s
        # Deterministic Euler step down to sigma_down ...
        latents = latents + d * dt
        # ... followed by re-injection of noise with standard deviation sigma_up.
        noise = torch.randn(
            latents.shape,
            layout=latents.layout,
            device=latents.device,
            dtype=latents.dtype,
            generator=generator,
        )
        latents = latents + noise * sigma_up

        if not return_dict:
            return (latents,)

        return SchedulerOutput(prev_sample=latents)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.FloatTensor,
        sample_prev: torch.FloatTensor,
        derivative: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network.

        A second derivative is computed at `sample_prev` and averaged with `derivative`; the averaged
        derivative is then applied to `sample_hat` over the interval `sigma_prev - sigma_hat`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            sigma_hat (`float`): sigma at which `sample_hat` lives.
            sigma_prev (`float`): sigma at which `sample_prev` lives; used as a divisor, so must be non-zero.
            sample_hat (`torch.FloatTensor`): sample the corrected step starts from.
            sample_prev (`torch.FloatTensor`): sample produced by the uncorrected step.
            derivative (`torch.FloatTensor`): derivative used for the uncorrected step.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] with the corrected sample as `prev_sample` if `return_dict` is True, otherwise
            the tuple `(sample_prev, derivative)` where `derivative` is the input derivative unchanged.
        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        if not return_dict:
            return (sample_prev, derivative)

        return SchedulerOutput(prev_sample=sample_prev)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add `noise` to `original_samples`, scaled by the sigma of each entry in `timesteps`.

        Also marks the input as already scaled, so a following `scale_model_input` call is a no-op.

        Args:
            original_samples (`torch.FloatTensor`): clean samples.
            noise (`torch.FloatTensor`): noise to add.
            timesteps (`torch.FloatTensor`):
                one schedule value per sample; each must be present in `self.timesteps`.

        Returns:
            `torch.FloatTensor`: `original_samples + noise * sigma`.
        """
        sigmas = self.sigmas.to(original_samples.device)
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Right-pad sigma with singleton dims so it broadcasts over the sample dims.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        self.is_scale_input_called = True
        return noisy_samples