diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 57 |
1 file changed, 44 insertions, 13 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 5f4fc38..f27be78 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -1,7 +1,7 @@ | |||
| 1 | import inspect | 1 | import inspect |
| 2 | import warnings | 2 | import warnings |
| 3 | import math | 3 | import math |
| 4 | from typing import List, Dict, Any, Optional, Union, Callable | 4 | from typing import List, Dict, Any, Optional, Union, Callable, Literal |
| 5 | 5 | ||
| 6 | import numpy as np | 6 | import numpy as np |
| 7 | import torch | 7 | import torch |
| @@ -22,7 +22,7 @@ from diffusers import ( | |||
| 22 | PNDMScheduler, | 22 | PNDMScheduler, |
| 23 | ) | 23 | ) |
| 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
| 25 | from diffusers.utils import logging | 25 | from diffusers.utils import logging, randn_tensor |
| 26 | from transformers import CLIPTextModel, CLIPTokenizer | 26 | from transformers import CLIPTextModel, CLIPTokenizer |
| 27 | 27 | ||
| 28 | from models.clip.util import unify_input_ids, get_extended_embeddings | 28 | from models.clip.util import unify_input_ids, get_extended_embeddings |
| @@ -312,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 312 | ).expand(batch_size, 3, width, height) | 312 | ).expand(batch_size, 3, width, height) |
| 313 | return (1.4 * noise).clamp(-1, 1) | 313 | return (1.4 * noise).clamp(-1, 1) |
| 314 | 314 | ||
| 315 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): | 315 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): |
| 316 | init_image = init_image.to(device=device, dtype=dtype) | 316 | init_image = init_image.to(device=device, dtype=dtype) |
| 317 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 317 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
| 318 | init_latents = self.vae.config.scaling_factor * init_latents | 318 | init_latents = self.vae.config.scaling_factor * init_latents |
| @@ -334,6 +334,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 334 | 334 | ||
| 335 | return latents | 335 | return latents |
| 336 | 336 | ||
| 337 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | ||
| 338 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | ||
| 339 | if isinstance(generator, list) and len(generator) != batch_size: | ||
| 340 | raise ValueError( | ||
| 341 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | ||
| 342 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." | ||
| 343 | ) | ||
| 344 | |||
| 345 | if latents is None: | ||
| 346 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | ||
| 347 | else: | ||
| 348 | latents = latents.to(device) | ||
| 349 | |||
| 350 | # scale the initial noise by the standard deviation required by the scheduler | ||
| 351 | latents = latents * self.scheduler.init_noise_sigma | ||
| 352 | return latents | ||
| 353 | |||
| 337 | def prepare_extra_step_kwargs(self, generator, eta): | 354 | def prepare_extra_step_kwargs(self, generator, eta): |
| 338 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 355 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
| 339 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 356 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
| @@ -373,7 +390,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 373 | sag_scale: float = 0.75, | 390 | sag_scale: float = 0.75, |
| 374 | eta: float = 0.0, | 391 | eta: float = 0.0, |
| 375 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 392 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| 376 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 393 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, |
| 377 | output_type: str = "pil", | 394 | output_type: str = "pil", |
| 378 | return_dict: bool = True, | 395 | return_dict: bool = True, |
| 379 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 396 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| @@ -443,8 +460,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 443 | # 2. Define call parameters | 460 | # 2. Define call parameters |
| 444 | batch_size = len(prompt) | 461 | batch_size = len(prompt) |
| 445 | device = self.execution_device | 462 | device = self.execution_device |
| 463 | num_channels_latents = self.unet.in_channels | ||
| 446 | do_classifier_free_guidance = guidance_scale > 1.0 | 464 | do_classifier_free_guidance = guidance_scale > 1.0 |
| 447 | do_self_attention_guidance = sag_scale > 0.0 | 465 | do_self_attention_guidance = sag_scale > 0.0 |
| 466 | prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" | ||
| 448 | 467 | ||
| 449 | # 3. Encode input prompt | 468 | # 3. Encode input prompt |
| 450 | prompt_embeds = self.encode_prompt( | 469 | prompt_embeds = self.encode_prompt( |
| @@ -458,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 458 | # 4. Prepare latent variables | 477 | # 4. Prepare latent variables |
| 459 | if isinstance(image, PIL.Image.Image): | 478 | if isinstance(image, PIL.Image.Image): |
| 460 | image = preprocess(image) | 479 | image = preprocess(image) |
| 461 | elif image is None: | 480 | elif image == "noise": |
| 462 | image = self.prepare_image( | 481 | image = self.prepare_image( |
| 463 | batch_size * num_images_per_prompt, | 482 | batch_size * num_images_per_prompt, |
| 464 | width, | 483 | width, |
| @@ -474,14 +493,26 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 474 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 493 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
| 475 | 494 | ||
| 476 | # 6. Prepare latent variables | 495 | # 6. Prepare latent variables |
| 477 | latents = self.prepare_latents( | 496 | if prep_from_image: |
| 478 | image, | 497 | latents = self.prepare_latents_from_image( |
| 479 | latent_timestep, | 498 | image, |
| 480 | batch_size * num_images_per_prompt, | 499 | latent_timestep, |
| 481 | prompt_embeds.dtype, | 500 | batch_size * num_images_per_prompt, |
| 482 | device, | 501 | prompt_embeds.dtype, |
| 483 | generator | 502 | device, |
| 484 | ) | 503 | generator |
| 504 | ) | ||
| 505 | else: | ||
| 506 | latents = self.prepare_latents( | ||
| 507 | batch_size, | ||
| 508 | num_channels_latents, | ||
| 509 | height, | ||
| 510 | width, | ||
| 511 | prompt_embeds.dtype, | ||
| 512 | device, | ||
| 513 | generator, | ||
| 514 | image | ||
| 515 | ) | ||
| 485 | 516 | ||
| 486 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | 517 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
| 487 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | 518 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
