Diffstat (limited to 'pipelines')
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py | 61 +++++++++++++++++++++++++++++++++----------------------------
1 file changed, 33 insertions(+), 28 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index ea2a656..127ca50 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -307,39 +307,45 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
-        return (1.4 * perlin_noise(
+    def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None):
+        offset_image = perlin_noise(
             (batch_size, 1, width, height),
             res=1,
-            octaves=4,
             generator=generator,
             dtype=dtype,
             device=device
-        )).clamp(-1, 1).expand(batch_size, 3, width, height)
+        )
+        offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator)
+        offset_latents = self.vae.config.scaling_factor * offset_latents
+        return offset_latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
-        init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
-        init_latents = self.vae.config.scaling_factor * init_latents
+        latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
+        latents = self.vae.config.scaling_factor * latents
 
-        if batch_size % init_latents.shape[0] != 0:
+        if batch_size % latents.shape[0] != 0:
             raise ValueError(
-                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                f"Cannot duplicate `init_image` of batch size {latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            batch_multiplier = batch_size // init_latents.shape[0]
-            init_latents = torch.cat([init_latents] * batch_multiplier, dim=0)
+            batch_multiplier = batch_size // latents.shape[0]
+            latents = torch.cat([latents] * batch_multiplier, dim=0)
 
         # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+        noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
+
+        if brightness_offset != 0:
+            noise += brightness_offset * self.prepare_brightness_offset(
+                batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
+            )
 
         # get latents
-        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-        latents = init_latents
+        latents = self.scheduler.add_noise(latents, noise, timestep)
 
         return latents
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
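The hunk above moves the brightness perturbation from pixel space (the old prepare_image, which returned a clamped 3-channel Perlin image) into latent space: the new prepare_brightness_offset encodes a single-channel Perlin image through the VAE so the result can be added directly to latent noise. A minimal sketch of how the helper combines with img2img noise, assuming a loaded VlpnStableDiffusion instance `pipe`; the shapes, the 0.2 magnitude, and the usual SD vae_scale_factor of 8 are illustrative assumptions, not taken from the commit:

    import torch

    batch_size, height, width = 1, 512, 512
    # Fresh latent noise, as sampled in prepare_latents_from_image
    # (4 latent channels, spatial dims divided by the VAE scale factor).
    noise = torch.randn(batch_size, 4, height // 8, width // 8)
    # Perlin-derived offset, already VAE-encoded into latent space.
    offset = pipe.prepare_brightness_offset(
        batch_size, height, width, noise.dtype, noise.device
    )
    noise = noise + 0.2 * offset  # 0.2 plays the role of brightness_offset
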
@@ -352,6 +358,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             latents = latents.to(device)
 
+        if brightness_offset != 0:
+            latents += brightness_offset * self.prepare_brightness_offset(
+                batch_size, height, width, dtype, device, generator
+            )
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
@@ -395,7 +406,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        brightness_offset: Union[float, torch.FloatTensor] = 0,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
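With the signature change above, callers request the effect explicitly instead of passing a magic image value. A hedged usage sketch, assuming the pipeline returns the usual diffusers-style output object with an images list; pipeline construction is elided, and the prompt and 0.15 are placeholder values:

    out = pipe(
        prompt="a photo of a cat",
        num_inference_steps=30,
        brightness_offset=0.15,  # the default of 0 disables the offset entirely
    )
    image = out.images[0]
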
@@ -468,7 +480,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         num_channels_latents = self.unet.in_channels
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
-        prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise"
+        prep_from_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -482,15 +494,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 4. Prepare latent variables
         if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
-        elif image == "noise":
-            image = self.prepare_image(
-                batch_size * num_images_per_prompt,
-                width,
-                height,
-                prompt_embeds.dtype,
-                device,
-                generator
-            )
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -503,9 +506,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 image,
                 latent_timestep,
                 batch_size * num_images_per_prompt,
+                brightness_offset,
                 prompt_embeds.dtype,
                 device,
-                generator
+                generator,
             )
         else:
             latents = self.prepare_latents(
@@ -513,10 +517,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 num_channels_latents,
                 height,
                 width,
+                brightness_offset,
                 prompt_embeds.dtype,
                 device,
                 generator,
-                image
+                image,
             )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
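
Net effect of the commit: both the img2img path (prepare_latents_from_image) and the text-to-image path (prepare_latents) accept a brightness_offset, the initial noise is perturbed by a VAE-encoded Perlin pattern whenever the offset is nonzero, and the old image="noise" escape hatch is removed along with prepare_image. A sketch of the migration for existing callers, with illustrative values:

    # before this commit: Perlin init image via a magic string
    out = pipe(prompt=prompt, image="noise", generator=generator)
    # after this commit: the offset is an explicit, tunable parameter
    out = pipe(prompt=prompt, brightness_offset=0.15, generator=generator)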