| author | Volpeon <git@volpeon.ink> | 2023-03-04 08:17:31 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-04 08:17:31 +0100 |
| commit | 5b80eb8dac50941c05209df9bb560959ab81bdb0 (patch) | |
| tree | 30a4902ee7b6a794b298addb1247d0daa4ecdf42 /pipelines | |
| parent | Changed init noise algorithm (diff) | |
Pipeline: Improved initial image generation
Diffstat (limited to 'pipelines')

-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 49
1 file changed, 26 insertions(+), 23 deletions(-)
```diff
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 242be29..2251848 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -295,16 +295,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start:]
 
         timesteps = timesteps.to(device)
 
-        return timesteps
+        return timesteps, num_inference_steps - t_start
 
     def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
         offset = (max_offset * (2 * torch.rand(
```
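The `get_timesteps` change drops the `steps_offset` bookkeeping and additionally returns how many denoising steps will actually run, so the caller can overwrite `num_inference_steps` with the truncated count (see the hunk around line 485 below). A minimal standalone sketch of the new arithmetic; `DDIMScheduler` is an assumed stand-in for whatever scheduler the pipeline uses:

```python
# Standalone sketch of the new get_timesteps arithmetic.
# DDIMScheduler is an assumed stand-in for the pipeline's scheduler.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
num_inference_steps, strength = 50, 0.8

scheduler.set_timesteps(num_inference_steps)

# New logic: a single clamp, no steps_offset correction.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
timesteps = scheduler.timesteps[t_start:]

# The extra return value is the number of steps that will actually run.
print(len(timesteps), num_inference_steps - t_start)  # 40 40
```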
```diff
@@ -312,12 +310,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             dtype=dtype,
             device=device,
             generator=generator
-        ) - 1)).expand(batch_size, 3, width, height)
-        image = (.1 * torch.normal(
-            mean=offset,
-            std=1,
-            generator=generator
-        )).clamp(-1, 1)
+        ) - 1)).expand(batch_size, 1, 2, 2)
+        image = F.interpolate(
+            torch.normal(
+                mean=offset,
+                std=0.3,
+                generator=generator
+            ).clamp(-1, 1),
+            size=(width, height),
+            mode="bicubic"
+        ).expand(batch_size, 3, width, height)
         return image
 
     def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
```
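Where the old `prepare_image` filled the whole canvas with per-pixel Gaussian noise around a single random colour offset, the new version draws the noise on a tiny 2×2 grid and bicubically upsamples it to the target resolution, so the generated init image is a smooth colour gradient rather than high-frequency speckle. A self-contained sketch of the new path; the shape passed to `torch.rand` falls between the two hunks, so it is an assumption here:

```python
import torch
import torch.nn.functional as F

batch_size, width, height = 2, 512, 512
max_offset = 0.7
generator = torch.Generator().manual_seed(0)

# One random mean colour in [-max_offset, max_offset] per image
# (the rand shape is assumed; that line is not shown in the diff).
offset = (max_offset * (2 * torch.rand(
    (batch_size, 1, 1, 1),
    generator=generator
) - 1)).expand(batch_size, 1, 2, 2)

# Gaussian noise on the 2x2 grid, clamped to the VAE's [-1, 1] range,
# then smoothly upsampled and repeated across the three colour channels.
image = F.interpolate(
    torch.normal(mean=offset, std=0.3, generator=generator).clamp(-1, 1),
    size=(width, height),
    mode="bicubic"
).expand(batch_size, 3, width, height)

print(image.shape)  # torch.Size([2, 3, 512, 512])
```

Note that clamping happens before interpolation, mirroring the committed code, so bicubic overshoot can push individual pixels slightly past [-1, 1].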
```diff
@@ -382,7 +384,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        max_image_offset: float = 1.0,
+        max_init_offset: float = 0.7,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
```
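Alongside the rename, the default tightens from 1.0 to 0.7, so the random base colour stays within 70% of the full [-1, 1] range around mid-grey. A hypothetical call site; the checkpoint path and prompt are placeholders, and the standard `DiffusionPipeline` loading and output conventions are assumed:

```python
# Hypothetical usage; model path and prompt are placeholders.
import torch
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

pipe = VlpnStableDiffusion.from_pretrained(
    "path/to/checkpoint", torch_dtype=torch.float16
).to("cuda")

result = pipe(
    prompt="a photo of a fox",
    num_inference_steps=50,
    strength=0.8,         # fraction of the schedule that is denoised
    max_init_offset=0.7,  # renamed from max_image_offset (old default: 1.0)
)
result.images[0].save("out.png")
```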
```diff
@@ -464,11 +466,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             device
         )
 
-        # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
-
-        # 5. Prepare latent variables
+        # 4. Prepare latent variables
         if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
         elif image is None:
```
```diff
@@ -476,13 +474,18 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 batch_size * num_images_per_prompt,
                 width,
                 height,
-                max_image_offset,
+                max_init_offset,
                 prompt_embeds.dtype,
                 device,
                 generator
             )
 
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
         latents = self.prepare_latents(
             image,
             latent_timestep,
```
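Beyond the renumbered comments, two behavioural details matter here: timestep preparation now runs after the init image is created, and the second value returned by `get_timesteps` overwrites `num_inference_steps`, so the denoising loop only iterates the truncated schedule. `latent_timestep` then broadcasts the first (noisiest) remaining timestep across the batch; a small sketch with illustrative values:

```python
import torch

# Illustrative truncated schedule; real values come from the scheduler.
timesteps = torch.tensor([801, 781, 761, 741])
batch_size, num_images_per_prompt = 2, 2

latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
print(latent_timestep)  # tensor([801, 801, 801, 801])
```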
```diff
@@ -492,10 +495,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             generator
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Denoising loop
+        # 8. Denoising loop
         if do_self_attention_guidance:
             store_processor = CrossAttnStoreProcessor()
             self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
```
```diff
@@ -559,13 +562,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        # 8. Post-processing
+        # 9. Post-processing
         image = self.decode_latents(latents)
 
-        # 9. Run safety checker
+        # 10. Run safety checker
         has_nsfw_concept = None
 
-        # 10. Convert to PIL
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
```