summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py27
1 files changed, 13 insertions, 14 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index e90528d..fc12355 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
12from diffusers.utils import logging 12from diffusers.utils import logging
13from transformers import CLIPTextModel, CLIPTokenizer 13from transformers import CLIPTextModel, CLIPTokenizer
14from schedulers.scheduling_euler_a import EulerAScheduler 14from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
15from models.clip.prompt import PromptProcessor 15from models.clip.prompt import PromptProcessor
16 16
17logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
32 text_encoder: CLIPTextModel, 32 text_encoder: CLIPTextModel,
33 tokenizer: CLIPTokenizer, 33 tokenizer: CLIPTokenizer,
34 unet: UNet2DConditionModel, 34 unet: UNet2DConditionModel,
35 scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler], 35 scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler],
36 **kwargs, 36 **kwargs,
37 ): 37 ):
38 super().__init__() 38 super().__init__()
@@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
225 init_timestep = int(num_inference_steps * strength) + offset 225 init_timestep = int(num_inference_steps * strength) + offset
226 init_timestep = min(init_timestep, num_inference_steps) 226 init_timestep = min(init_timestep, num_inference_steps)
227 227
228 timesteps = self.scheduler.timesteps[-init_timestep] 228 if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
229 timesteps = torch.tensor([timesteps] * batch_size, device=self.device) 229 timesteps = torch.tensor(
230 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
231 )
232 else:
233 timesteps = self.scheduler.timesteps[-init_timestep]
234 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
230 235
231 # add noise to latents using the timesteps 236 # add noise to latents using the timesteps
232 noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) 237 noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
@@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
259 for i, t in enumerate(self.progress_bar(timesteps_tensor)): 264 for i, t in enumerate(self.progress_bar(timesteps_tensor)):
260 # expand the latents if we are doing classifier free guidance 265 # expand the latents if we are doing classifier free guidance
261 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 266 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
262 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 267 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)
263 268
264 noise_pred = None 269 # predict the noise residual
265 if isinstance(self.scheduler, EulerAScheduler): 270 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
266 c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
267 eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
268 noise_pred = latent_model_input + eps * c_out
269 else:
270 # predict the noise residual
271 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
272 271
273 # perform guidance 272 # perform guidance
274 if do_classifier_free_guidance: 273 if do_classifier_free_guidance:
@@ -276,7 +275,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
276 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 275 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
277 276
278 # compute the previous noisy sample x_t -> x_t-1 277 # compute the previous noisy sample x_t -> x_t-1
279 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 278 latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample
280 279
281 # scale and decode the image latents with vae 280 # scale and decode the image latents with vae
282 latents = 1 / 0.18215 * latents 281 latents = 1 / 0.18215 * latents