Diffstat (limited to 'pipelines')

 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 78a34d5..141b9a7 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess(image, w, h):
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
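
The hunk above cuts off before `preprocess()` returns. For reference, here is a minimal runnable sketch of the full updated function; the tail (tensor conversion and scaling to [-1, 1]) is not shown in this diff and is assumed from the conventional diffusers img2img preprocessing:

```python
import numpy as np
import PIL.Image
import torch

def preprocess(image: PIL.Image.Image) -> torch.Tensor:
    w, h = image.size
    # round each side down to a multiple of 8 so the VAE's 8x spatial
    # downsampling produces integer-sized latents
    w, h = map(lambda x: x - x % 8, (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0  # HWC, values in [0, 1]
    image = image[None].transpose(0, 3, 1, 2)           # add batch dim -> NCHW
    image = torch.from_numpy(image)                     # assumed tail, per the
    return 2.0 * image - 1.0                            # diffusers convention
```

Dropping the explicit `w, h` parameters means callers no longer have to pre-compute valid dimensions; the function now derives them from the image itself.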
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
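
With the parameter renamed from `latents_or_image` to `image`, an img2img call now looks like the hypothetical sketch below; `pipe`, the prompt, and the file name are illustrative, and only parameters visible in this diff are set:

```python
import PIL.Image

init_image = PIL.Image.open("init.png").convert("RGB")  # hypothetical input
result = pipe(  # `pipe` assumed to be an already-loaded VlpnStableDiffusion
    prompt="a fantasy landscape",
    image=init_image,   # was passed as `latents_or_image=` before this change
    guidance_scale=7.5,
    output_type="pil",
)
```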
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
+        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         text_embeddings = self.encode_prompt(
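
A note on the unchanged `do_classifier_free_guidance = guidance_scale > 1.0` context line: at a scale of exactly 1.0 the usual guidance combination collapses to the conditional prediction, so the extra unconditional pass is skipped. A sketch of the diffusers-style formula (not this file's exact code):

```python
import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # `noise_pred` stacks the unconditional and text-conditional halves,
    # as produced by running the UNet once on a doubled batch
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # at guidance_scale == 1.0 this reduces to noise_pred_text
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```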
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         if latents_are_image:
+            image = preprocess(image)
             latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             latents = self.prepare_latents_from_image(
-                latents_or_image,
+                image,
                 latent_timestep,
                 batch_size,
                 num_images_per_prompt,
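
The new `image = preprocess(image)` line turns the PIL input into a normalized tensor before latent preparation, and the `latent_timestep` repeat broadcasts the first (already offset) timestep across the whole batch. The body of `prepare_latents_from_image()` is not part of this diff; the sketch below shows what such a helper conventionally does in diffusers-style img2img pipelines, under that assumption:

```python
import torch

def prepare_latents_from_image_sketch(vae, scheduler, image, timestep,
                                      generator, device, dtype):
    # encode the preprocessed [-1, 1] image tensor into VAE latent space
    image = image.to(device=device, dtype=dtype)
    init_latents = vae.encode(image).latent_dist.sample(generator)
    init_latents = 0.18215 * init_latents  # SD v1 latent scaling factor
    # noise the clean latents up to the starting timestep so denoising can
    # resume mid-schedule, which is what makes img2img a partial denoise
    noise = torch.randn(init_latents.shape, generator=generator,
                        device=device, dtype=dtype)
    return scheduler.add_noise(init_latents, noise, timestep)
```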
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             text_embeddings.dtype,
             device,
             generator,
-            latents_or_image,
+            image,
         )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
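
For context on the trailing line: the extra-step-kwargs logic the TODO refers to conventionally inspects the scheduler's `step()` signature so that DDIM's eta is only forwarded to schedulers that accept it. A sketch of that diffusers-style convention, not necessarily this pipeline's exact body:

```python
import inspect

def prepare_extra_step_kwargs_sketch(scheduler, generator, eta):
    extra = {}
    step_params = inspect.signature(scheduler.step).parameters
    if "eta" in step_params:        # only DDIM-style schedulers take eta
        extra["eta"] = eta
    if "generator" in step_params:  # some schedulers accept a generator
        extra["generator"] = generator
    return extra
```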
