diff options
author | Volpeon <git@volpeon.ink> | 2022-11-14 19:48:27 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-14 19:48:27 +0100 |
commit | 90879448c1ae92f39bdcabdf89230891c62e1408 (patch) | |
tree | e58ac942ec52a087e082b024c6cd7127f0327c36 /pipelines | |
parent | Refactoring (diff) | |
download | textual-inversion-diff-90879448c1ae92f39bdcabdf89230891c62e1408.tar.gz textual-inversion-diff-90879448c1ae92f39bdcabdf89230891c62e1408.tar.bz2 textual-inversion-diff-90879448c1ae92f39bdcabdf89230891c62e1408.zip |
Update
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 7 |
1 files changed, 1 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index d6b1cb1..85b0216 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -245,10 +245,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
245 | init_latents = init_latent_dist.sample(generator=generator) | 245 | init_latents = init_latent_dist.sample(generator=generator) |
246 | init_latents = 0.18215 * init_latents | 246 | init_latents = 0.18215 * init_latents |
247 | 247 | ||
248 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: | 248 | if batch_size > init_latents.shape[0]: |
249 | additional_image_per_prompt = batch_size // init_latents.shape[0] | ||
250 | init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) | ||
251 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: | ||
252 | raise ValueError( | 249 | raise ValueError( |
253 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | 250 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." |
254 | ) | 251 | ) |
@@ -367,8 +364,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
367 | do_classifier_free_guidance = guidance_scale > 1.0 | 364 | do_classifier_free_guidance = guidance_scale > 1.0 |
368 | latents_are_image = isinstance(latents_or_image, PIL.Image.Image) | 365 | latents_are_image = isinstance(latents_or_image, PIL.Image.Image) |
369 | 366 | ||
370 | print(f">>> {device}") | ||
371 | |||
372 | # 3. Encode input prompt | 367 | # 3. Encode input prompt |
373 | text_embeddings = self.encode_prompt( | 368 | text_embeddings = self.encode_prompt( |
374 | prompt, | 369 | prompt, |