From dd02ace41f69541044e9db106feaa76bf02da8f6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 12 Dec 2022 08:05:06 +0100
Subject: Dreambooth: Support loading Textual Inversion embeddings

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 78a34d5..141b9a7 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess(image, w, h):
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
+        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         text_embeddings = self.encode_prompt(
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         if latents_are_image:
+            image = preprocess(image)
             latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             latents = self.prepare_latents_from_image(
-                latents_or_image,
+                image,
                 latent_timestep,
                 batch_size,
                 num_images_per_prompt,
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 text_embeddings.dtype,
                 device,
                 generator,
-                latents_or_image,
+                image,
             )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-- 
cgit v1.2.3-70-g09d2
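
For context: the reworked preprocess() derives the working resolution from the
input image itself and snaps each dimension down to the nearest multiple of 8,
since the VAE maps pixels to latents at 1/8 scale and sizes that are not
multiples of 8 would produce mismatched latent shapes. A minimal standalone
sketch of that rounding step (the helper name snap_to_multiple_of_8 is
illustrative, not from the patch):

    import PIL.Image

    def snap_to_multiple_of_8(size):
        # Round (w, h) down to the nearest multiples of 8, mirroring the
        # `x - x % 8` logic added to preprocess() above.
        return tuple(x - x % 8 for x in size)

    image = PIL.Image.new("RGB", (513, 769))
    w, h = snap_to_multiple_of_8(image.size)
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    assert image.size == (512, 768)

On the caller side, the img2img input is now passed via the renamed `image`
keyword (previously `latents_or_image`) as either a PIL.Image.Image or a
latent tensor; resizing happens inside the pipeline, so callers no longer
compute w and h themselves.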