From 13b0d9f763269df405d1aeba86213f1c7ce4e7ca Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 15:14:29 +0200 Subject: More consistent euler_a --- infer.py | 2 +- .../stable_diffusion/vlpn_stable_diffusion.py | 17 +++---- schedulers/scheduling_euler_a.py | 59 ++++++++++------------ 3 files changed, 33 insertions(+), 45 deletions(-) diff --git a/infer.py b/infer.py index b440cb6..c40335c 100644 --- a/infer.py +++ b/infer.py @@ -176,7 +176,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16): ) else: scheduler = EulerAScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) pipeline = VlpnStableDiffusion( diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4c793a8..a8ecedf 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline): latent_dist = self.vae.encode(latents.to(self.device)).latent_dist latents = latent_dist.sample(generator=generator) latents = 0.18215 * latents + + # expand init_latents for batch_size latents = torch.cat([latents] * batch_size) # get the original timestep using init_timestep @@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline): timesteps = torch.tensor( [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device ) - elif isinstance(self.scheduler, EulerAScheduler): - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, device=self.device) else: timesteps = self.scheduler.timesteps[-init_timestep] timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) @@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline): if isinstance(self.scheduler, LMSDiscreteScheduler): latents = latents * self.scheduler.sigmas[0] elif isinstance(self.scheduler, EulerAScheduler): - sigma = self.scheduler.timesteps[0] - latents = latents * sigma + latents = latents * self.scheduler.sigmas[0] # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred = None if isinstance(self.scheduler, EulerAScheduler): - sigma = t.reshape(1) + sigma = self.scheduler.sigmas[t].reshape(1) sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) - # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) - # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample else: # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline): if isinstance(self.scheduler, LMSDiscreteScheduler): latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample elif isinstance(self.scheduler, EulerAScheduler): - if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error - t_prev = self.scheduler.timesteps[t_index+1] - latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t_index, t_index + 1, + latents, **extra_step_kwargs).prev_sample else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 9fbedaa..1b1c9cf 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -1,7 +1,3 @@ - - -import math -import warnings from typing import Optional, Tuple, Union import numpy as np @@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - clip_sample: bool = True, - set_alpha_to_one: bool = True, - steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] @@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): the number of diffusion steps used when generating samples with a pre-trained model. """ - # offset = self.config.steps_offset - - # if "offset" in kwargs: - # warnings.warn( - # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - # " Please pass `steps_offset` to `__init__` instead.", - # DeprecationWarning, - # ) - - # offset = kwargs["offset"] - self.num_inference_steps = num_inference_steps self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device) - self.timesteps = self.sigmas + self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) + self.timesteps = np.arange(0, self.num_inference_steps) def add_noise_to_input( self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None @@ -239,8 +215,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.FloatTensor, - timestep: torch.IntTensor, - timestep_prev: torch.IntTensor, + timestep: int, + timestep_prev: int, sample: torch.FloatTensor, generator: None, # ,sigma_hat: float, @@ -266,13 +242,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): returning a tuple, the first element is the sample tensor. """ + s = self.sigmas[timestep] + s_prev = self.sigmas[timestep_prev] latents = sample - sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) - d = to_d(latents, timestep, model_output) - dt = sigma_down - timestep + + sigma_down, sigma_up = get_ancestral_step(s, s_prev) + d = to_d(latents, s, model_output) + dt = sigma_down - s latents = latents + d * dt latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, generator=generator) * sigma_up + return SchedulerOutput(prev_sample=latents) def step_correct( @@ -311,5 +291,18 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): return SchedulerOutput(prev_sample=sample_prev) - def add_noise(self, original_samples, noise, timesteps): - raise NotImplementedError() + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + sigmas = self.sigmas.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + sigma = sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples -- cgit v1.2.3-54-g00ecf