From 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 6 Oct 2022 17:15:22 +0200 Subject: Update --- schedulers/scheduling_euler_a.py | 45 +++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) (limited to 'schedulers') diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index c6436d8..13ea6b3 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -171,6 +171,9 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] @@ -190,13 +193,33 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = num_inference_steps self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) - self.timesteps = np.arange(0, self.num_inference_steps) + self.timesteps = self.sigmas[:-1] + self.is_scale_input_called = False + + def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + if self.is_scale_input_called: + return sample + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample * sigma + self.is_scale_input_called = True + return sample def step( self, model_output: torch.FloatTensor, - timestep: int, - timestep_prev: int, + timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, generator: torch.Generator = None, return_dict: bool = True, @@ -219,8 +242,13 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): returning a tuple, the first element is the sample tensor. """ - s = self.sigmas[timestep] - s_prev = self.sigmas[timestep_prev] + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + step_prev_index = step_index + 1 + + s = self.sigmas[step_index] + s_prev = self.sigmas[step_prev_index] latents = sample sigma_down, sigma_up = get_ancestral_step(s, s_prev) @@ -271,14 +299,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: sigmas = self.sigmas.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = sigmas[timesteps].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma + self.is_scale_input_called = True return noisy_samples -- cgit v1.2.3-70-g09d2