summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-14 19:48:27 +0100
committerVolpeon <git@volpeon.ink>2022-11-14 19:48:27 +0100
commit90879448c1ae92f39bdcabdf89230891c62e1408 (patch)
treee58ac942ec52a087e082b024c6cd7127f0327c36
parentRefactoring (diff)
downloadtextual-inversion-diff-90879448c1ae92f39bdcabdf89230891c62e1408.tar.gz
textual-inversion-diff-90879448c1ae92f39bdcabdf89230891c62e1408.tar.bz2
textual-inversion-diff-90879448c1ae92f39bdcabdf89230891c62e1408.zip
Update
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py7
1 files changed, 1 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index d6b1cb1..85b0216 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -245,10 +245,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
245 init_latents = init_latent_dist.sample(generator=generator) 245 init_latents = init_latent_dist.sample(generator=generator)
246 init_latents = 0.18215 * init_latents 246 init_latents = 0.18215 * init_latents
247 247
248 if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: 248 if batch_size > init_latents.shape[0]:
249 additional_image_per_prompt = batch_size // init_latents.shape[0]
250 init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
251 elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
252 raise ValueError( 249 raise ValueError(
253 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 250 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
254 ) 251 )
@@ -367,8 +364,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
367 do_classifier_free_guidance = guidance_scale > 1.0 364 do_classifier_free_guidance = guidance_scale > 1.0
368 latents_are_image = isinstance(latents_or_image, PIL.Image.Image) 365 latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
369 366
370 print(f">>> {device}")
371
372 # 3. Encode input prompt 367 # 3. Encode input prompt
373 text_embeddings = self.encode_prompt( 368 text_embeddings = self.encode_prompt(
374 prompt, 369 prompt,