diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 35 |
1 files changed, 5 insertions, 30 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 8fbe5f9..a198cf6 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -216,7 +216,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 216 | 216 | ||
| 217 | offset = self.scheduler.config.get("steps_offset", 0) | 217 | offset = self.scheduler.config.get("steps_offset", 0) |
| 218 | init_timestep = num_inference_steps + offset | 218 | init_timestep = num_inference_steps + offset |
| 219 | ensure_sigma = not isinstance(latents, PIL.Image.Image) | ||
| 220 | 219 | ||
| 221 | # get the initial random noise unless the user supplied it | 220 | # get the initial random noise unless the user supplied it |
| 222 | 221 | ||
| @@ -246,13 +245,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 246 | init_timestep = int(num_inference_steps * strength) + offset | 245 | init_timestep = int(num_inference_steps * strength) + offset |
| 247 | init_timestep = min(init_timestep, num_inference_steps) | 246 | init_timestep = min(init_timestep, num_inference_steps) |
| 248 | 247 | ||
| 249 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 248 | timesteps = self.scheduler.timesteps[-init_timestep] |
| 250 | timesteps = torch.tensor( | 249 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) |
| 251 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
| 252 | ) | ||
| 253 | else: | ||
| 254 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
| 255 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | ||
| 256 | 250 | ||
| 257 | # add noise to latents using the timesteps | 251 | # add noise to latents using the timesteps |
| 258 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | 252 | noise = torch.randn(latents.shape, generator=generator, device=self.device) |
| @@ -263,13 +257,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 263 | if latents.device != self.device: | 257 | if latents.device != self.device: |
| 264 | raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") | 258 | raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") |
| 265 | 259 | ||
| 266 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
| 267 | if ensure_sigma: | ||
| 268 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 269 | latents = latents * self.scheduler.sigmas[0] | ||
| 270 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 271 | latents = latents * self.scheduler.sigmas[0] | ||
| 272 | |||
| 273 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 260 | t_start = max(num_inference_steps - init_timestep + offset, 0) |
| 274 | 261 | ||
| 275 | # Some schedulers like PNDM have timesteps as arrays | 262 | # Some schedulers like PNDM have timesteps as arrays |
| @@ -290,19 +277,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 290 | extra_step_kwargs["generator"] = generator | 277 | extra_step_kwargs["generator"] = generator |
| 291 | 278 | ||
| 292 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 279 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): |
| 293 | t_index = t_start + i | ||
| 294 | |||
| 295 | # expand the latents if we are doing classifier free guidance | 280 | # expand the latents if we are doing classifier free guidance |
| 296 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 281 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 297 | 282 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| 298 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 299 | sigma = self.scheduler.sigmas[t_index] | ||
| 300 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | ||
| 301 | latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) | ||
| 302 | 283 | ||
| 303 | noise_pred = None | 284 | noise_pred = None |
| 304 | if isinstance(self.scheduler, EulerAScheduler): | 285 | if isinstance(self.scheduler, EulerAScheduler): |
| 305 | sigma = self.scheduler.sigmas[t].reshape(1) | 286 | sigma = t.reshape(1) |
| 306 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) | 287 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) |
| 307 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 288 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, |
| 308 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) | 289 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) |
| @@ -316,13 +297,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 316 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 297 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| 317 | 298 | ||
| 318 | # compute the previous noisy sample x_t -> x_t-1 | 299 | # compute the previous noisy sample x_t -> x_t-1 |
| 319 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 300 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
| 320 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample | ||
| 321 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 322 | latents = self.scheduler.step(noise_pred, t_index, t_index + 1, | ||
| 323 | latents, **extra_step_kwargs).prev_sample | ||
| 324 | else: | ||
| 325 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | ||
| 326 | 301 | ||
| 327 | # scale and decode the image latents with vae | 302 | # scale and decode the image latents with vae |
| 328 | latents = 1 / 0.18215 * latents | 303 | latents = 1 / 0.18215 * latents |
