diff options
author | Volpeon <git@volpeon.ink> | 2023-03-04 19:24:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-04 19:24:24 +0100 |
commit | bc28ad0e0355916cb7e0b2df5ee0992f2e0b427c (patch) | |
tree | 88505e6fb13666ba459577935151aab43ee019d2 /pipelines | |
parent | Added Perlin noise to training (diff) | |
download | textual-inversion-diff-bc28ad0e0355916cb7e0b2df5ee0992f2e0b427c.tar.gz textual-inversion-diff-bc28ad0e0355916cb7e0b2df5ee0992f2e0b427c.tar.bz2 textual-inversion-diff-bc28ad0e0355916cb7e0b2df5ee0992f2e0b427c.zip |
More flexible pipeline wrt init noise
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 57 |
1 files changed, 44 insertions, 13 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 5f4fc38..f27be78 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -1,7 +1,7 @@ | |||
1 | import inspect | 1 | import inspect |
2 | import warnings | 2 | import warnings |
3 | import math | 3 | import math |
4 | from typing import List, Dict, Any, Optional, Union, Callable | 4 | from typing import List, Dict, Any, Optional, Union, Callable, Literal |
5 | 5 | ||
6 | import numpy as np | 6 | import numpy as np |
7 | import torch | 7 | import torch |
@@ -22,7 +22,7 @@ from diffusers import ( | |||
22 | PNDMScheduler, | 22 | PNDMScheduler, |
23 | ) | 23 | ) |
24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
25 | from diffusers.utils import logging | 25 | from diffusers.utils import logging, randn_tensor |
26 | from transformers import CLIPTextModel, CLIPTokenizer | 26 | from transformers import CLIPTextModel, CLIPTokenizer |
27 | 27 | ||
28 | from models.clip.util import unify_input_ids, get_extended_embeddings | 28 | from models.clip.util import unify_input_ids, get_extended_embeddings |
@@ -312,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
312 | ).expand(batch_size, 3, width, height) | 312 | ).expand(batch_size, 3, width, height) |
313 | return (1.4 * noise).clamp(-1, 1) | 313 | return (1.4 * noise).clamp(-1, 1) |
314 | 314 | ||
315 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): | 315 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): |
316 | init_image = init_image.to(device=device, dtype=dtype) | 316 | init_image = init_image.to(device=device, dtype=dtype) |
317 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 317 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
318 | init_latents = self.vae.config.scaling_factor * init_latents | 318 | init_latents = self.vae.config.scaling_factor * init_latents |
@@ -334,6 +334,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
334 | 334 | ||
335 | return latents | 335 | return latents |
336 | 336 | ||
337 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | ||
338 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | ||
339 | if isinstance(generator, list) and len(generator) != batch_size: | ||
340 | raise ValueError( | ||
341 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | ||
342 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." | ||
343 | ) | ||
344 | |||
345 | if latents is None: | ||
346 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | ||
347 | else: | ||
348 | latents = latents.to(device) | ||
349 | |||
350 | # scale the initial noise by the standard deviation required by the scheduler | ||
351 | latents = latents * self.scheduler.init_noise_sigma | ||
352 | return latents | ||
353 | |||
337 | def prepare_extra_step_kwargs(self, generator, eta): | 354 | def prepare_extra_step_kwargs(self, generator, eta): |
338 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 355 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
339 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 356 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
@@ -373,7 +390,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
373 | sag_scale: float = 0.75, | 390 | sag_scale: float = 0.75, |
374 | eta: float = 0.0, | 391 | eta: float = 0.0, |
375 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 392 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
376 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 393 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, |
377 | output_type: str = "pil", | 394 | output_type: str = "pil", |
378 | return_dict: bool = True, | 395 | return_dict: bool = True, |
379 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 396 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
@@ -443,8 +460,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
443 | # 2. Define call parameters | 460 | # 2. Define call parameters |
444 | batch_size = len(prompt) | 461 | batch_size = len(prompt) |
445 | device = self.execution_device | 462 | device = self.execution_device |
463 | num_channels_latents = self.unet.in_channels | ||
446 | do_classifier_free_guidance = guidance_scale > 1.0 | 464 | do_classifier_free_guidance = guidance_scale > 1.0 |
447 | do_self_attention_guidance = sag_scale > 0.0 | 465 | do_self_attention_guidance = sag_scale > 0.0 |
466 | prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" | ||
448 | 467 | ||
449 | # 3. Encode input prompt | 468 | # 3. Encode input prompt |
450 | prompt_embeds = self.encode_prompt( | 469 | prompt_embeds = self.encode_prompt( |
@@ -458,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
458 | # 4. Prepare latent variables | 477 | # 4. Prepare latent variables |
459 | if isinstance(image, PIL.Image.Image): | 478 | if isinstance(image, PIL.Image.Image): |
460 | image = preprocess(image) | 479 | image = preprocess(image) |
461 | elif image is None: | 480 | elif image == "noise": |
462 | image = self.prepare_image( | 481 | image = self.prepare_image( |
463 | batch_size * num_images_per_prompt, | 482 | batch_size * num_images_per_prompt, |
464 | width, | 483 | width, |
@@ -474,14 +493,26 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
474 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 493 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
475 | 494 | ||
476 | # 6. Prepare latent variables | 495 | # 6. Prepare latent variables |
477 | latents = self.prepare_latents( | 496 | if prep_from_image: |
478 | image, | 497 | latents = self.prepare_latents_from_image( |
479 | latent_timestep, | 498 | image, |
480 | batch_size * num_images_per_prompt, | 499 | latent_timestep, |
481 | prompt_embeds.dtype, | 500 | batch_size * num_images_per_prompt, |
482 | device, | 501 | prompt_embeds.dtype, |
483 | generator | 502 | device, |
484 | ) | 503 | generator |
504 | ) | ||
505 | else: | ||
506 | latents = self.prepare_latents( | ||
507 | batch_size, | ||
508 | num_channels_latents, | ||
509 | height, | ||
510 | width, | ||
511 | prompt_embeds.dtype, | ||
512 | device, | ||
513 | generator, | ||
514 | image | ||
515 | ) | ||
485 | 516 | ||
486 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | 517 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
487 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | 518 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |