 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 49 ++++++++++++++++++++++++++-----------------------
 1 file changed, 26 insertions(+), 23 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 242be29..2251848 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -295,16 +295,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start:]
 
         timesteps = timesteps.to(device)
 
-        return timesteps
+        return timesteps, num_inference_steps - t_start
 
     def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
         offset = (max_offset * (2 * torch.rand(
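The steps_offset handling is dropped, and the method now also reports how many steps will actually run. A quick standalone check of the new arithmetic (plain Python; the scheduler's timestep tensor is faked here as a list, the real method slices self.scheduler.timesteps):

def get_timesteps(all_timesteps, num_inference_steps, strength):
    # mirrors the new method body above
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return all_timesteps[t_start:], num_inference_steps - t_start

timesteps, n_steps = get_timesteps(list(range(999, -1, -20)), 50, strength=0.6)
assert n_steps == 30 and len(timesteps) == 30  # only the last 30 of 50 steps run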
@@ -312,12 +310,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             dtype=dtype,
             device=device,
             generator=generator
-        ) - 1)).expand(batch_size, 3, width, height)
-        image = (.1 * torch.normal(
-            mean=offset,
-            std=1,
-            generator=generator
-        )).clamp(-1, 1)
+        ) - 1)).expand(batch_size, 1, 2, 2)
+        image = F.interpolate(
+            torch.normal(
+                mean=offset,
+                std=0.3,
+                generator=generator
+            ).clamp(-1, 1),
+            size=(width, height),
+            mode="bicubic"
+        ).expand(batch_size, 3, width, height)
         return image
 
     def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
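The old code added per-pixel Gaussian noise at full resolution; the new code draws a 2x2 noise field around a random per-image offset and bicubic-upsamples it, giving a smooth, low-frequency init image. A minimal standalone sketch (the rand shape feeding .expand() is an assumption, since the diff truncates that call; batch_size/width/height/max_offset are example values):

import torch
import torch.nn.functional as F

batch_size, width, height, max_offset = 2, 64, 64, 0.7
# random mean offset in [-max_offset, max_offset], one value per image
offset = (max_offset * (2 * torch.rand(batch_size, 1, 1, 1) - 1)).expand(batch_size, 1, 2, 2)
image = F.interpolate(
    torch.normal(mean=offset, std=0.3).clamp(-1, 1),  # coarse 2x2 noise around the offset
    size=(width, height),
    mode="bicubic",
).expand(batch_size, 3, width, height)  # single channel broadcast to RGB
print(image.shape)  # torch.Size([2, 3, 64, 64])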
@@ -382,7 +384,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        max_image_offset: float = 1.0,
+        max_init_offset: float = 0.7,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
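The keyword is renamed and its default lowered from 1.0 to 0.7, so existing call sites need updating. A hypothetical invocation, assuming pipe is an already-loaded VlpnStableDiffusion:

out = pipe("a photo of an astronaut", strength=0.8, max_init_offset=0.7)

Assuming __call__ has no **kwargs catch-all, passing the old max_image_offset keyword would now raise a TypeError.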
@@ -464,11 +466,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             device
         )
 
-        # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
-
-        # 5. Prepare latent variables
+        # 4. Prepare latent variables
         if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
         elif image is None:
@@ -476,13 +474,18 @@ class VlpnStableDiffusion(DiffusionPipeline):
             batch_size * num_images_per_prompt,
             width,
             height,
-            max_image_offset,
+            max_init_offset,
             prompt_embeds.dtype,
             device,
             generator
         )
 
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
         latents = self.prepare_latents(
             image,
             latent_timestep,
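Together with the previous hunk, this moves scheduler setup after image preparation and threads the truncated step count back into num_inference_steps. A condensed paraphrase of the new ordering (not code from the commit; argument lists abbreviated to the signatures shown above):

image = self.prepare_image(batch_size * num_images_per_prompt, width, height,
                           max_init_offset, prompt_embeds.dtype, device, generator)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents = self.prepare_latents(image, latent_timestep, batch_size,
                               prompt_embeds.dtype, device, generator)
# the denoising loop below now iterates only the strength-truncated timesteps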
@@ -492,10 +495,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             generator
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Denoising loo
+        # 8. Denoising loo
         if do_self_attention_guidance:
             store_processor = CrossAttnStoreProcessor()
             self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
@@ -559,13 +562,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        # 8. Post-processing
+        # 9. Post-processing
         image = self.decode_latents(latents)
 
-        # 9. Run safety checker
+        # 10. Run safety checker
         has_nsfw_concept = None
 
-        # 10. Convert to PIL
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 