From 13b0d9f763269df405d1aeba86213f1c7ce4e7ca Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 15:14:29 +0200 Subject: More consistent euler_a --- schedulers/scheduling_euler_a.py | 59 ++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 33 deletions(-) (limited to 'schedulers') diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 9fbedaa..1b1c9cf 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -1,7 +1,3 @@ - - -import math -import warnings from typing import Optional, Tuple, Union import numpy as np @@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - clip_sample: bool = True, - set_alpha_to_one: bool = True, - steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] @@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): the number of diffusion steps used when generating samples with a pre-trained model. """ - # offset = self.config.steps_offset - - # if "offset" in kwargs: - # warnings.warn( - # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - # " Please pass `steps_offset` to `__init__` instead.", - # DeprecationWarning, - # ) - - # offset = kwargs["offset"] - self.num_inference_steps = num_inference_steps self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) - self.timesteps = self.sigmas + self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) + self.timesteps = np.arange(0, self.num_inference_steps) def add_noise_to_input( self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None @@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.FloatTensor, - timestep: torch.IntTensor, - timestep_prev: torch.IntTensor, + timestep: int, + timestep_prev: int, sample: torch.FloatTensor, generator: None, # ,sigma_hat: float, @@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): returning a tuple, the first element is the sample tensor. """ + s = self.sigmas[timestep] + s_prev = self.sigmas[timestep_prev] latents = sample - sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) - d = to_d(latents, timestep, model_output) - dt = sigma_down - timestep + + sigma_down, sigma_up = get_ancestral_step(s, s_prev) + d = to_d(latents, s, model_output) + dt = sigma_down - s latents = latents + d * dt latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, generator=generator) * sigma_up + return SchedulerOutput(prev_sample=latents) def step_correct( @@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): return SchedulerOutput(prev_sample=sample_prev) - def add_noise(self, original_samples, noise, timesteps): - raise NotImplementedError() + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + sigmas = self.sigmas.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + sigma = sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples -- cgit v1.2.3-70-g09d2