diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index e90528d..fc12355 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre | |||
11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
12 | from diffusers.utils import logging | 12 | from diffusers.utils import logging |
13 | from transformers import CLIPTextModel, CLIPTokenizer | 13 | from transformers import CLIPTextModel, CLIPTokenizer |
14 | from schedulers.scheduling_euler_a import EulerAScheduler | 14 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
15 | from models.clip.prompt import PromptProcessor | 15 | from models.clip.prompt import PromptProcessor |
16 | 16 | ||
17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
32 | text_encoder: CLIPTextModel, | 32 | text_encoder: CLIPTextModel, |
33 | tokenizer: CLIPTokenizer, | 33 | tokenizer: CLIPTokenizer, |
34 | unet: UNet2DConditionModel, | 34 | unet: UNet2DConditionModel, |
35 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler], | 35 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], |
36 | **kwargs, | 36 | **kwargs, |
37 | ): | 37 | ): |
38 | super().__init__() | 38 | super().__init__() |
@@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
225 | init_timestep = int(num_inference_steps * strength) + offset | 225 | init_timestep = int(num_inference_steps * strength) + offset |
226 | init_timestep = min(init_timestep, num_inference_steps) | 226 | init_timestep = min(init_timestep, num_inference_steps) |
227 | 227 | ||
228 | timesteps = self.scheduler.timesteps[-init_timestep] | 228 | if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): |
229 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 229 | timesteps = torch.tensor( |
230 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
231 | ) | ||
232 | else: | ||
233 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
234 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | ||
230 | 235 | ||
231 | # add noise to latents using the timesteps | 236 | # add noise to latents using the timesteps |
232 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) | 237 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) |
@@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
259 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 264 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): |
260 | # expand the latents if we are doing classifier free guidance | 265 | # expand the latents if we are doing classifier free guidance |
261 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 266 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
262 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 267 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i) |
263 | 268 | ||
264 | noise_pred = None | 269 | # predict the noise residual |
265 | if isinstance(self.scheduler, EulerAScheduler): | 270 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
266 | c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) | ||
267 | eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
268 | noise_pred = latent_model_input + eps * c_out | ||
269 | else: | ||
270 | # predict the noise residual | ||
271 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | ||
272 | 271 | ||
273 | # perform guidance | 272 | # perform guidance |
274 | if do_classifier_free_guidance: | 273 | if do_classifier_free_guidance: |
@@ -276,7 +275,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
276 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 275 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
277 | 276 | ||
278 | # compute the previous noisy sample x_t -> x_t-1 | 277 | # compute the previous noisy sample x_t -> x_t-1 |
279 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 278 | latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample |
280 | 279 | ||
281 | # scale and decode the image latents with vae | 280 | # scale and decode the image latents with vae |
282 | latents = 1 / 0.18215 * latents | 281 | latents = 1 / 0.18215 * latents |