Diffstat (limited to 'pipelines/stable_diffusion')
 -rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 32 +++++++++-----------------------
 1 file changed, 9 insertions(+), 23 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
 
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
         return latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = self.prepare_latents_from_image(
                 image,
                 latent_timestep,
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 text_embeddings.dtype,
                 device,
                 generator
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
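
The removed block hand-rolled two special cases: per-sample draws when `generator` is a list, and sampling on CPU for mps devices before moving the result. The new code delegates both to diffusers' `randn_tensor` helper, imported above from `diffusers.utils` (newer diffusers releases export it from `diffusers.utils.torch_utils` instead). A minimal sketch of the helper with a generator list; the shape and seeds are illustrative only:

    import torch
    from diffusers.utils import randn_tensor

    # One generator per sample: randn_tensor draws each sample with its own
    # generator and concatenates, matching the removed list branch above.
    shape = (4, 4, 64, 64)  # (batch_size * num_images_per_prompt, latent channels, h, w)
    generators = [torch.Generator().manual_seed(seed) for seed in range(shape[0])]

    latents = randn_tensor(shape, generator=generators, device=torch.device("cpu"), dtype=torch.float32)
    print(latents.shape)  # torch.Size([4, 4, 64, 64])

Note that `prepare_latents` and `prepare_latents_from_image` now take a single pre-multiplied batch size; the call sites fold `num_images_per_prompt` into `batch_size` instead of passing it separately.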
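
The hard-coded `// 8` also becomes `height // self.vae_scale_factor`. The diff does not show where `vae_scale_factor` is set; stock diffusers pipelines derive it from the VAE config, and a sketch under that assumption (the helper name here is hypothetical) would be:

    from diffusers import AutoencoderKL

    def compute_vae_scale_factor(vae: AutoencoderKL) -> int:
        # Each VAE down block past the first halves the spatial resolution,
        # so latents are 2 ** (num_blocks - 1) times smaller than the image.
        # The standard Stable Diffusion VAE has 4 block_out_channels, which
        # yields 8 and keeps the new expression equivalent to the old `// 8`.
        return 2 ** (len(vae.config.block_out_channels) - 1)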
