From bc28ad0e0355916cb7e0b2df5ee0992f2e0b427c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 4 Mar 2023 19:24:24 +0100 Subject: More flexible pipeline wrt init noise --- .../stable_diffusion/vlpn_stable_diffusion.py | 57 +++++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 5f4fc38..f27be78 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -1,7 +1,7 @@ import inspect import warnings import math -from typing import List, Dict, Any, Optional, Union, Callable +from typing import List, Dict, Any, Optional, Union, Callable, Literal import numpy as np import torch @@ -22,7 +22,7 @@ from diffusers import ( PNDMScheduler, ) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import logging +from diffusers.utils import logging, randn_tensor from transformers import CLIPTextModel, CLIPTokenizer from models.clip.util import unify_input_ids, get_extended_embeddings @@ -312,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline): ).expand(batch_size, 3, width, height) return (1.4 * noise).clamp(-1, 1) - def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): + def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) init_latents = self.vae.config.scaling_factor * init_latents @@ -334,6 +334,23 @@ class VlpnStableDiffusion(DiffusionPipeline): return latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -373,7 +390,7 @@ class VlpnStableDiffusion(DiffusionPipeline): sag_scale: float = 0.75, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, output_type: str = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -443,8 +460,10 @@ class VlpnStableDiffusion(DiffusionPipeline): # 2. Define call parameters batch_size = len(prompt) device = self.execution_device + num_channels_latents = self.unet.in_channels do_classifier_free_guidance = guidance_scale > 1.0 do_self_attention_guidance = sag_scale > 0.0 + prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" # 3. Encode input prompt prompt_embeds = self.encode_prompt( @@ -458,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline): # 4. Prepare latent variables if isinstance(image, PIL.Image.Image): image = preprocess(image) - elif image is None: + elif image == "noise": image = self.prepare_image( batch_size * num_images_per_prompt, width, @@ -474,14 +493,26 @@ class VlpnStableDiffusion(DiffusionPipeline): latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size * num_images_per_prompt, - prompt_embeds.dtype, - device, - generator - ) + if prep_from_image: + latents = self.prepare_latents_from_image( + image, + latent_timestep, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator + ) + else: + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + image + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) -- cgit v1.2.3-70-g09d2