author    Volpeon <git@volpeon.ink>    2023-03-25 16:34:48 +0100
committer Volpeon <git@volpeon.ink>    2023-03-25 16:34:48 +0100
commit    6b8a93f46f053668c8023520225a18445d48d8f1 (patch)
tree      463c8835a9a90dd9b5586a13e55d6882caa3103a /pipelines/stable_diffusion
parent    Update (diff)
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  61
1 file changed, 33 insertions(+), 28 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index ea2a656..127ca50 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -307,39 +307,45 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
-        return (1.4 * perlin_noise(
+    def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None):
+        offset_image = perlin_noise(
             (batch_size, 1, width, height),
             res=1,
-            octaves=4,
             generator=generator,
             dtype=dtype,
             device=device
-        )).clamp(-1, 1).expand(batch_size, 3, width, height)
+        )
+        offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator)
+        offset_latents = self.vae.config.scaling_factor * offset_latents
+        return offset_latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
-        init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
-        init_latents = self.vae.config.scaling_factor * init_latents
+        latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
+        latents = self.vae.config.scaling_factor * latents
 
-        if batch_size % init_latents.shape[0] != 0:
+        if batch_size % latents.shape[0] != 0:
             raise ValueError(
-                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                f"Cannot duplicate `init_image` of batch size {latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            batch_multiplier = batch_size // init_latents.shape[0]
-            init_latents = torch.cat([init_latents] * batch_multiplier, dim=0)
+            batch_multiplier = batch_size // latents.shape[0]
+            latents = torch.cat([latents] * batch_multiplier, dim=0)
 
         # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+        noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
+
+        if brightness_offset != 0:
+            noise += brightness_offset * self.prepare_brightness_offset(
+                batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
+            )
 
         # get latents
-        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-        latents = init_latents
+        latents = self.scheduler.add_noise(latents, noise, timestep)
 
         return latents
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
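Note: the new prepare_brightness_offset draws single-octave (res=1) Perlin noise, encodes it with the VAE, and scales it by vae.config.scaling_factor, so the offset lives in the same latent space as the noise it is later added to. Below is a standalone sketch of that idea, not the repo's code: low_freq_noise is a stand-in for the repo's perlin_noise helper, and the sketch feeds the VAE a three-channel image (which a stock Stable Diffusion AutoencoderKL expects; the repo's helper may handle channels differently).

import torch
import torch.nn.functional as F

def low_freq_noise(shape, generator=None, dtype=None, device=None):
    # Stand-in for the repo's perlin_noise helper: upsample a coarse 2x2
    # Gaussian grid into a smooth, low-frequency image in roughly [-1, 1].
    b, c, h, w = shape
    coarse = torch.randn((b, c, 2, 2), generator=generator, device=device)
    smooth = F.interpolate(coarse, size=(h, w), mode="bicubic", align_corners=True)
    return smooth.clamp(-1, 1).to(dtype=dtype)

@torch.no_grad()
def brightness_offset_latents(vae, batch_size, height, width, dtype, device, generator=None):
    # Mirrors the new method: encode the low-frequency image, then scale it
    # into the latent space used by the scheduler.
    offset_image = low_freq_noise(
        (batch_size, 3, height, width), generator=generator, dtype=dtype, device=device
    )
    offset_latents = vae.encode(offset_image).latent_dist.sample(generator=generator)
    return vae.config.scaling_factor * offset_latents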
@@ -352,6 +358,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             latents = latents.to(device)
 
+        if brightness_offset != 0:
+            latents += brightness_offset * self.prepare_brightness_offset(
+                batch_size, height, width, dtype, device, generator
+            )
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
@@ -395,7 +406,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        brightness_offset: Union[float, torch.FloatTensor] = 0,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
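Note: with this signature change, image="noise" is no longer accepted; the Perlin-based variation is requested through the new brightness_offset argument instead, and the latent-preparation methods apply it themselves. A hedged before/after sketch; pipe, the prompt, and the offset value are placeholders, and the return value is assumed to be the usual StableDiffusionPipelineOutput:

# Before this commit, a Perlin-noise init image could be requested directly:
#   result = pipe(prompt, image="noise")
# After this commit, start from Gaussian noise (image=None, the default) and
# ask for the low-frequency offset explicitly; 0 keeps the old default:
result = pipe(prompt, brightness_offset=0.3)
images = result.images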
@@ -468,7 +480,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         num_channels_latents = self.unet.in_channels
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
-        prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise"
+        prep_from_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -482,15 +494,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 4. Prepare latent variables
         if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
-        elif image == "noise":
-            image = self.prepare_image(
-                batch_size * num_images_per_prompt,
-                width,
-                height,
-                prompt_embeds.dtype,
-                device,
-                generator
-            )
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -503,9 +506,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 image,
                 latent_timestep,
                 batch_size * num_images_per_prompt,
+                brightness_offset,
                 prompt_embeds.dtype,
                 device,
-                generator
+                generator,
             )
         else:
             latents = self.prepare_latents(
@@ -513,10 +517,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 num_channels_latents,
                 height,
                 width,
+                brightness_offset,
                 prompt_embeds.dtype,
                 device,
                 generator,
-                image
+                image,
             )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
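Note: taken together, both latent paths now accept the offset. A usage sketch, assuming the pipeline loads like any DiffusionPipeline subclass; the model id, prompt, file name, and offset values are placeholders, not part of this commit:

import torch
from PIL import Image
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

pipe = VlpnStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)

# Text-to-image: the offset is added to the initial latents in prepare_latents.
out = pipe("a photo of a cat", generator=generator, brightness_offset=0.2)

# Image-to-image: the offset is mixed into the noise that perturbs the encoded
# init image in prepare_latents_from_image.
init_image = Image.open("init.png").convert("RGB")
out = pipe("a photo of a cat", image=init_image, generator=generator, brightness_offset=0.2)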