 infer.py                                            |  2 +-
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 17 ++++++-----------
 schedulers/scheduling_euler_a.py                    | 59 ++++++++++++++++++++++++---------------------------
 3 files changed, 33 insertions(+), 45 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -176,7 +176,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         )
     else:
         scheduler = EulerAScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
 
    pipeline = VlpnStableDiffusion(
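Note: `clip_sample` and `set_alpha_to_one` were DDIM-era options that the trimmed `EulerAScheduler.__init__` below no longer accepts, so the call site has to drop them too. A minimal sketch of the call site after this change; the step count and device are illustrative values, not taken from the diff:

```python
# Sketch only; import path as in this repo, 50 steps and "cpu" are examples.
from schedulers.scheduling_euler_a import EulerAScheduler

scheduler = EulerAScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
scheduler.set_timesteps(50, device="cpu")
# timesteps is now a plain integer range 0..49; if get_sigmas appends a
# trailing zero as in k-diffusion, sigmas has 51 entries.
```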
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 4c793a8..a8ecedf 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
             latents = latent_dist.sample(generator=generator)
             latents = 0.18215 * latents
+
+            # expand init_latents for batch_size
             latents = torch.cat([latents] * batch_size)
 
             # get the original timestep using init_timestep
@@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             timesteps = torch.tensor(
                 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
             )
-        elif isinstance(self.scheduler, EulerAScheduler):
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
         else:
             timesteps = self.scheduler.timesteps[-init_timestep]
             timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
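With `EulerAScheduler.timesteps` now an integer range rather than a sigma tensor, the dedicated branch is redundant: the generic `else` branch yields the same integer index, and with the correct `dtype=torch.long`. A worked example with hypothetical values:

```python
# Hypothetical values: 50 inference steps and an img2img strength that gives
# init_timestep == 38; timesteps[-init_timestep] is then just the index 12.
import numpy as np

num_inference_steps, init_timestep = 50, 38
timesteps = np.arange(0, num_inference_steps)
assert timesteps[-init_timestep] == num_inference_steps - init_timestep  # 12
```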
@@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if isinstance(self.scheduler, LMSDiscreteScheduler):
             latents = latents * self.scheduler.sigmas[0]
         elif isinstance(self.scheduler, EulerAScheduler):
-            sigma = self.scheduler.timesteps[0]
-            latents = latents * sigma
+            latents = latents * self.scheduler.sigmas[0]
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
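This fix matters because `timesteps[0]` is now an integer index, not a sigma. Sigma-based samplers start from `x ~ N(0, sigma_max^2)`, so the unit-variance initial latents must be scaled by the first (largest) sigma. A sketch of the convention, with `scheduler` prepared as in the infer.py sketch above and illustrative shapes:

```python
# k-diffusion convention: sampling starts at the largest noise level, so
# unit-variance Gaussian latents are scaled by sigmas[0] (roughly 14.6 for
# the stable-diffusion scaled_linear betas). Shapes are illustrative.
import torch

latents = torch.randn(1, 4, 64, 64)
latents = latents * scheduler.sigmas[0]
```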
@@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
             noise_pred = None
             if isinstance(self.scheduler, EulerAScheduler):
-                sigma = t.reshape(1)
+                sigma = self.scheduler.sigmas[t].reshape(1)
                 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
-                # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
                 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
                                                 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
-                # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
             else:
                 # predict the noise residual
                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
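Since `t` now iterates over integer indices, the sigma fed to the denoiser must be looked up via `sigmas[t]` instead of being `t` itself. `CFGDenoiserForward` is this repo's helper and its internals are not shown in the diff; the classifier-free-guidance combine it presumably ends with looks like this hypothetical sketch, not the repo's actual code:

```python
import torch

def cfg_combine(noise_pred_uncond: torch.Tensor,
                noise_pred_text: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: push the conditional prediction
    # away from the unconditional one by the guidance scale.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```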
@@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[t_index+1]
-                    latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, t_index + 1,
+                                              latents, **extra_step_kwargs).prev_sample
             else:
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
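The old bounds check (which silently skipped the final step) can go because `set_timesteps` now asks `get_sigmas` for `num_inference_steps` points; assuming the k-diffusion convention of a trailing zero sigma, `sigmas` has one more entry than there are steps, so `sigmas[t_index + 1]` is valid even on the last iteration. The loop shape this implies, as a sketch where `denoise` is a hypothetical stand-in for the CFG call above:

```python
# Sketch, assuming len(scheduler.sigmas) == num_inference_steps + 1 and that
# latents start as sigma-scaled noise per the earlier hunk.
generator = None  # or a torch.Generator for reproducible ancestral noise
for t_index, t in enumerate(scheduler.timesteps):
    noise_pred = denoise(latents, scheduler.sigmas[t])  # hypothetical helper
    latents = scheduler.step(noise_pred, t_index, t_index + 1,
                             latents, generator=generator).prev_sample
```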
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 9fbedaa..1b1c9cf 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -1,7 +1,3 @@
-
-
-import math
-import warnings
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -157,9 +153,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        clip_sample: bool = True,
-        set_alpha_to_one: bool = True,
-        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -177,12 +170,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
-        # At every step in ddim, we are looking into the previous alphas_cumprod
-        # For the final step, there is no previous alphas_cumprod because we are already at 0
-        # `set_alpha_to_one` decides whether we set this parameter simply to one or
-        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-
         # setable values
         self.num_inference_steps = None
         self.timesteps = np.arange(0, num_train_timesteps)[::-1]
@@ -199,21 +186,10 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             the number of diffusion steps used when generating samples with a pre-trained model.
         """
 
-        # offset = self.config.steps_offset
-
-        # if "offset" in kwargs:
-        #     warnings.warn(
-        #         "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
-        #         " Please pass `steps_offset` to `__init__` instead.",
-        #         DeprecationWarning,
-        #     )
-
-        #     offset = kwargs["offset"]
-
         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
-        self.timesteps = self.sigmas
+        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.timesteps = np.arange(0, self.num_inference_steps)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
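This is the pivotal change of the commit: `timesteps` used to alias the sigma tensor itself, which made consumers conflate indices with noise levels (the bugs fixed in the pipeline above). Now `timesteps` is a plain integer range and sigmas are looked up by index. For orientation, a sketch of what a k-diffusion-style `get_sigmas` does; this is assumed behavior, since the actual implementation is not shown in the diff:

```python
import torch

def get_sigmas_sketch(ds_sigmas: torch.Tensor, n: int) -> torch.Tensor:
    # n points linearly interpolated over the training sigmas, in descending
    # order, with a final zero appended -> n + 1 values total.
    t = torch.linspace(len(ds_sigmas) - 1, 0, n)
    w = t.frac()
    sigmas = (1 - w) * ds_sigmas[t.floor().long()] + w * ds_sigmas[t.ceil().long()]
    return torch.cat([sigmas, sigmas.new_zeros(1)])
```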
@@ -239,8 +215,8 @@
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: torch.IntTensor,
-        timestep_prev: torch.IntTensor,
+        timestep: int,
+        timestep_prev: int,
         sample: torch.FloatTensor,
         generator: None,
         # ,sigma_hat: float,
@@ -266,13 +242,17 @@
             returning a tuple, the first element is the sample tensor.
 
         """
+        s = self.sigmas[timestep]
+        s_prev = self.sigmas[timestep_prev]
         latents = sample
-        sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
-        d = to_d(latents, timestep, model_output)
-        dt = sigma_down - timestep
+
+        sigma_down, sigma_up = get_ancestral_step(s, s_prev)
+        d = to_d(latents, s, model_output)
+        dt = sigma_down - s
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
                                         generator=generator) * sigma_up
+
         return SchedulerOutput(prev_sample=latents)
 
     def step_correct(
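`step` now maps both integer indices to sigmas before performing a single Euler-ancestral update. For reference, the two helpers it leans on, as defined in Crowson's k-diffusion and assumed to match this module's local copies:

```python
def get_ancestral_step(sigma_from, sigma_to):
    # Split the move from sigma_from to sigma_to into a deterministic part
    # (step down to sigma_down) plus fresh noise at scale sigma_up, keeping
    # the total variance consistent with sigma_to.
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

def to_d(x, sigma, denoised):
    # Convert a denoised prediction into the ODE derivative dx/dsigma.
    return (x - denoised) / sigma
```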
@@ -311,5 +291,18 @@
 
         return SchedulerOutput(prev_sample=sample_prev)
 
-    def add_noise(self, original_samples, noise, timesteps):
-        raise NotImplementedError()
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        sigmas = self.sigmas.to(original_samples.device)
+        timesteps = timesteps.to(original_samples.device)
+
+        sigma = sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
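`add_noise` goes from a stub to a real implementation, which is what lets training code and img2img use this scheduler. Note the variance-exploding convention: the original sample is left unscaled and `sigma * noise` is added on top, unlike DDPM-style schedulers that also rescale the sample. A usage sketch with hypothetical shapes, with `scheduler` prepared as in the earlier sketch:

```python
import torch

latents = torch.randn(4, 4, 64, 64)             # a batch of clean latents
noise = torch.randn_like(latents)
t = torch.randint(0, scheduler.num_inference_steps, (4,))  # one level per sample
noisy = scheduler.add_noise(latents, noise, t)  # x + sigma_t * eps
```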
