summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py13
1 file changed, 8 insertions, 5 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 78a34d5..141b9a7 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def preprocess(image, w, h):
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
+        latents_are_image = isinstance(image, PIL.Image.Image)

         # 3. Encode input prompt
         text_embeddings = self.encode_prompt(
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         if latents_are_image:
+            image = preprocess(image)
             latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             latents = self.prepare_latents_from_image(
-                latents_or_image,
+                image,
                 latent_timestep,
                 batch_size,
                 num_images_per_prompt,
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 text_embeddings.dtype,
                 device,
                 generator,
-                latents_or_image,
+                image,
             )

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline