summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py95
1 files changed, 39 insertions, 56 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb09fe1..c4f7401 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -293,53 +293,39 @@ class VlpnStableDiffusion(DiffusionPipeline):
293 293
294 return prompt_embeds 294 return prompt_embeds
295 295
296 def get_timesteps(self, latents_are_image, num_inference_steps, strength, device): 296 def get_timesteps(self, num_inference_steps, strength, device):
297 if latents_are_image: 297 # get the original timestep using init_timestep
298 # get the original timestep using init_timestep 298 offset = self.scheduler.config.get("steps_offset", 0)
299 offset = self.scheduler.config.get("steps_offset", 0) 299 init_timestep = int(num_inference_steps * strength) + offset
300 init_timestep = int(num_inference_steps * strength) + offset 300 init_timestep = min(init_timestep, num_inference_steps)
301 init_timestep = min(init_timestep, num_inference_steps) 301
302 302 t_start = max(num_inference_steps - init_timestep + offset, 0)
303 t_start = max(num_inference_steps - init_timestep + offset, 0) 303 timesteps = self.scheduler.timesteps[t_start:]
304 timesteps = self.scheduler.timesteps[t_start:]
305 else:
306 timesteps = self.scheduler.timesteps
307 304
308 timesteps = timesteps.to(device) 305 timesteps = timesteps.to(device)
309 306
310 return timesteps 307 return timesteps
311 308
312 def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 309 def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
313 shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 310 return torch.randn(
314 311 (batch_size, 1, 1, 1),
315 if isinstance(generator, list) and len(generator) != batch_size: 312 dtype=dtype,
316 raise ValueError( 313 device=device,
317 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 314 generator=generator
318 f" size of {batch_size}. Make sure the batch size matches the length of the generators." 315 ).expand(batch_size, 3, width, height)
319 )
320 316
321 if latents is None: 317 def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
322 latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
323 else:
324 latents = latents.to(device=device, dtype=dtype)
325
326 # scale the initial noise by the standard deviation required by the scheduler
327 latents = latents * self.scheduler.init_noise_sigma
328
329 return latents
330
331 def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
332 init_image = init_image.to(device=device, dtype=dtype) 318 init_image = init_image.to(device=device, dtype=dtype)
333 init_latent_dist = self.vae.encode(init_image).latent_dist 319 init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
334 init_latents = init_latent_dist.sample(generator=generator) 320 init_latents = self.vae.config.scaling_factor * init_latents
335 init_latents = 0.18215 * init_latents
336 321
337 if batch_size > init_latents.shape[0]: 322 if batch_size % init_latents.shape[0] != 0:
338 raise ValueError( 323 raise ValueError(
339 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 324 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
340 ) 325 )
341 else: 326 else:
342 init_latents = torch.cat([init_latents] * batch_size, dim=0) 327 batch_multiplier = batch_size // init_latents.shape[0]
328 init_latents = torch.cat([init_latents] * batch_multiplier, dim=0)
343 329
344 # add noise to latents using the timesteps 330 # add noise to latents using the timesteps
345 noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) 331 noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -368,7 +354,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
368 return extra_step_kwargs 354 return extra_step_kwargs
369 355
370 def decode_latents(self, latents): 356 def decode_latents(self, latents):
371 latents = 1 / 0.18215 * latents 357 latents = 1 / self.vae.config.scaling_factor * latents
372 image = self.vae.decode(latents).sample 358 image = self.vae.decode(latents).sample
373 image = (image / 2 + 0.5).clamp(0, 1) 359 image = (image / 2 + 0.5).clamp(0, 1)
374 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 360 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -381,7 +367,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
381 prompt: Union[str, List[str], List[int], List[List[int]]], 367 prompt: Union[str, List[str], List[int], List[List[int]]],
382 negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, 368 negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
383 num_images_per_prompt: int = 1, 369 num_images_per_prompt: int = 1,
384 strength: float = 0.8, 370 strength: float = 1.0,
385 height: Optional[int] = None, 371 height: Optional[int] = None,
386 width: Optional[int] = None, 372 width: Optional[int] = None,
387 num_inference_steps: int = 50, 373 num_inference_steps: int = 50,
@@ -461,7 +447,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
461 device = self.execution_device 447 device = self.execution_device
462 do_classifier_free_guidance = guidance_scale > 1.0 448 do_classifier_free_guidance = guidance_scale > 1.0
463 do_self_attention_guidance = sag_scale > 0.0 449 do_self_attention_guidance = sag_scale > 0.0
464 latents_are_image = isinstance(image, PIL.Image.Image)
465 450
466 # 3. Encode input prompt 451 # 3. Encode input prompt
467 prompt_embeds = self.encode_prompt( 452 prompt_embeds = self.encode_prompt(
@@ -474,33 +459,31 @@ class VlpnStableDiffusion(DiffusionPipeline):
474 459
475 # 4. Prepare timesteps 460 # 4. Prepare timesteps
476 self.scheduler.set_timesteps(num_inference_steps, device=device) 461 self.scheduler.set_timesteps(num_inference_steps, device=device)
477 timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device) 462 timesteps = self.get_timesteps(num_inference_steps, strength, device)
478 463
479 # 5. Prepare latent variables 464 # 5. Prepare latent variables
480 num_channels_latents = self.unet.in_channels 465 if isinstance(image, PIL.Image.Image):
481 if latents_are_image:
482 image = preprocess(image) 466 image = preprocess(image)
483 latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 467 elif image is None:
484 latents = self.prepare_latents_from_image( 468 image = self.prepare_image(
485 image,
486 latent_timestep,
487 batch_size * num_images_per_prompt, 469 batch_size * num_images_per_prompt,
488 prompt_embeds.dtype,
489 device,
490 generator
491 )
492 else:
493 latents = self.prepare_latents(
494 batch_size * num_images_per_prompt,
495 num_channels_latents,
496 height,
497 width, 470 width,
471 height,
498 prompt_embeds.dtype, 472 prompt_embeds.dtype,
499 device, 473 device,
500 generator, 474 generator
501 image,
502 ) 475 )
503 476
477 latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
478 latents = self.prepare_latents(
479 image,
480 latent_timestep,
481 batch_size * num_images_per_prompt,
482 prompt_embeds.dtype,
483 device,
484 generator
485 )
486
504 # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 487 # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
505 extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 488 extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
506 489