Diffstat (limited to 'pipelines/stable_diffusion')

 pipelines/stable_diffusion/vlpn_stable_diffusion.py (renamed from pipelines/stable_diffusion/clip_guided_stable_diffusion.py) | 80
 1 file changed, 64 insertions(+), 16 deletions(-)
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index eff74b5..4c793a8 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -2,22 +2,29 @@ import inspect
 import warnings
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
+import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class CLIPGuidedStableDiffusion(DiffusionPipeline):
+def preprocess(image, w, h):
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -83,13 +90,14 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
@@ -99,6 +107,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -158,6 +172,42 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}")
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        offset = self.scheduler.config.get("steps_offset", 0)
+
+        if latents is not None and isinstance(latents, PIL.Image.Image):
+            latents = preprocess(latents, width, height)
+            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latent_dist.sample(generator=generator)
+            latents = 0.18215 * latents
+            latents = torch.cat([latents] * batch_size)
+
+            # get the original timestep using init_timestep
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            elif isinstance(self.scheduler, EulerAScheduler):
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        else:
+            init_timestep = num_inference_steps + offset
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -213,15 +263,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
 
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimized to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -244,10 +290,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            t_index = t_start + i
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
+                sigma = self.scheduler.sigmas[t_index]
                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
@@ -270,10 +318,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if i < self.scheduler.timesteps.shape[0] - 1:  # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[i+1]
+                if t_index < self.scheduler.timesteps.shape[0] - 1:  # avoid out of bound error
+                    t_prev = self.scheduler.timesteps[t_index+1]
                     latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
             else:
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
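
A minimal, hypothetical usage sketch of the img2img path this diff introduces: a PIL image is passed through the existing `latents` argument and `strength` decides how many of the scheduled steps actually run (with 50 steps and `strength=0.8`, `init_timestep` is 40, `t_start` is 10, so 40 denoising steps execute). The constructor components beyond `vae`, the checkpoint id, file names, and prompt are assumptions not shown in this diff.

    # hypothetical example, not part of the commit
    import torch
    from PIL import Image
    from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
    from transformers import CLIPTextModel, CLIPTokenizer

    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    model = "CompVis/stable-diffusion-v1-4"  # placeholder checkpoint id

    # Assumes the pipeline takes the usual SD components; only `vae` is visible above.
    pipe = VlpnStableDiffusion(
        vae=AutoencoderKL.from_pretrained(model, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(model, subfolder="text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model, subfolder="tokenizer"),
        unet=UNet2DConditionModel.from_pretrained(model, subfolder="unet"),
        scheduler=DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"),
    ).to("cuda")

    init_image = Image.open("sketch.png").convert("RGB")  # placeholder input image

    result = pipe(
        prompt="a watercolor landscape",  # placeholder prompt
        latents=init_image,               # a PIL image triggers the new img2img branch
        strength=0.8,                     # 50 steps -> init_timestep = 40, t_start = 10
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    image = result.images[0]

Passing a `torch.FloatTensor` (or nothing) for `latents` instead keeps the original text-to-image behaviour, since the encode-and-noise branch only fires for `PIL.Image.Image` inputs.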
