diff options
author | Volpeon <git@volpeon.ink> | 2022-12-12 08:05:06 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-12 08:05:06 +0100 |
commit | dd02ace41f69541044e9db106feaa76bf02da8f6 (patch) | |
tree | 8f6a8735acac9ebcf7396a40c632fa81c936701a /pipelines/stable_diffusion | |
parent | Remove embedding checkpoints from Dreambooth training (diff) | |
download | textual-inversion-diff-dd02ace41f69541044e9db106feaa76bf02da8f6.tar.gz textual-inversion-diff-dd02ace41f69541044e9db106feaa76bf02da8f6.tar.bz2 textual-inversion-diff-dd02ace41f69541044e9db106feaa76bf02da8f6.zip |
Dreambooth: Support loading Textual Inversion embeddings
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 78a34d5..141b9a7 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor | |||
27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
28 | 28 | ||
29 | 29 | ||
30 | def preprocess(image, w, h): | 30 | def preprocess(image): |
31 | w, h = image.size | ||
32 | w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 | ||
31 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) | 33 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) |
32 | image = np.array(image).astype(np.float32) / 255.0 | 34 | image = np.array(image).astype(np.float32) / 255.0 |
33 | image = image[None].transpose(0, 3, 1, 2) | 35 | image = image[None].transpose(0, 3, 1, 2) |
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
310 | guidance_scale: Optional[float] = 7.5, | 312 | guidance_scale: Optional[float] = 7.5, |
311 | eta: Optional[float] = 0.0, | 313 | eta: Optional[float] = 0.0, |
312 | generator: Optional[torch.Generator] = None, | 314 | generator: Optional[torch.Generator] = None, |
313 | latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 315 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
314 | output_type: Optional[str] = "pil", | 316 | output_type: Optional[str] = "pil", |
315 | return_dict: bool = True, | 317 | return_dict: bool = True, |
316 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 318 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
373 | batch_size = len(prompt) | 375 | batch_size = len(prompt) |
374 | device = self.execution_device | 376 | device = self.execution_device |
375 | do_classifier_free_guidance = guidance_scale > 1.0 | 377 | do_classifier_free_guidance = guidance_scale > 1.0 |
376 | latents_are_image = isinstance(latents_or_image, PIL.Image.Image) | 378 | latents_are_image = isinstance(image, PIL.Image.Image) |
377 | 379 | ||
378 | # 3. Encode input prompt | 380 | # 3. Encode input prompt |
379 | text_embeddings = self.encode_prompt( | 381 | text_embeddings = self.encode_prompt( |
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
391 | # 5. Prepare latent variables | 393 | # 5. Prepare latent variables |
392 | num_channels_latents = self.unet.in_channels | 394 | num_channels_latents = self.unet.in_channels |
393 | if latents_are_image: | 395 | if latents_are_image: |
396 | image = preprocess(image) | ||
394 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 397 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
395 | latents = self.prepare_latents_from_image( | 398 | latents = self.prepare_latents_from_image( |
396 | latents_or_image, | 399 | image, |
397 | latent_timestep, | 400 | latent_timestep, |
398 | batch_size, | 401 | batch_size, |
399 | num_images_per_prompt, | 402 | num_images_per_prompt, |
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
411 | text_embeddings.dtype, | 414 | text_embeddings.dtype, |
412 | device, | 415 | device, |
413 | generator, | 416 | generator, |
414 | latents_or_image, | 417 | image, |
415 | ) | 418 | ) |
416 | 419 | ||
417 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | 420 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |