Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py (renamed from pipelines/stable_diffusion/clip_guided_stable_diffusion.py)  | 80
1 file changed, 64 insertions(+), 16 deletions(-)
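The renamed pipeline keeps the text-to-image behaviour but adds an image-to-image path: `latents` may now be a `PIL.Image.Image`, which is VAE-encoded and noised according to the new `strength` argument. A minimal, hedged usage sketch (not part of the diff; the weights directory, image file, and prompts are placeholders, and loading via `from_pretrained` assumes the weights were saved in diffusers format for this pipeline class):

```python
# Hedged usage sketch -- not part of the diff. The import path follows the repository
# layout above; the weights directory is a placeholder and assumes a pipeline saved in
# diffusers format for this class (e.g. via save_pretrained()).
import torch
from PIL import Image

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

pipe = VlpnStableDiffusion.from_pretrained("path/to/stable-diffusion-weights").to("cuda")

init_image = Image.open("input.png").convert("RGB")  # placeholder init image

output = pipe(
    prompt="a watercolor landscape",
    negative_prompt="blurry, low quality",
    latents=init_image,        # passing a PIL image selects the new img2img path
    strength=0.8,              # 0.0 keeps the init image, 1.0 effectively ignores it
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
image = output.images[0]
```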
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index eff74b5..4c793a8 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -2,22 +2,29 @@ import inspect
 import warnings
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
+import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class CLIPGuidedStableDiffusion(DiffusionPipeline):
+def preprocess(image, w, h):
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -83,13 +90,14 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
@@ -99,6 +107,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -158,6 +172,42 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        offset = self.scheduler.config.get("steps_offset", 0)
+
+        if latents is not None and isinstance(latents, PIL.Image.Image):
+            latents = preprocess(latents, width, height)
+            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latent_dist.sample(generator=generator)
+            latents = 0.18215 * latents
+            latents = torch.cat([latents] * batch_size)
+
+            # get the original timestep using init_timestep
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            elif isinstance(self.scheduler, EulerAScheduler):
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        else:
+            init_timestep = num_inference_steps + offset
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
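The hunk above mirrors the timestep bookkeeping of diffusers' stock img2img pipeline: `strength` decides how deep into the noise schedule the encoded init image is injected, while the `else` branch preserves the old text-to-image behaviour. A plain-Python sketch of the arithmetic with the defaults from this diff (`offset = 1` is only an example value of a scheduler's `steps_offset`; it is 0 for schedulers without that config entry):

```python
# Plain-Python sketch of the schedule arithmetic above, using the defaults in this diff.
num_inference_steps = 50
strength = 0.8
offset = 1  # example steps_offset; 0 if the scheduler config has none

# img2img branch: start part-way into the schedule
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 41
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10

# txt2img branch (no init image): run the whole schedule
init_timestep_txt2img = num_inference_steps + offset
t_start_txt2img = max(num_inference_steps - init_timestep_txt2img + offset, 0)          # 0

print(t_start, num_inference_steps - t_start)  # 10 40 -> the last 80% of the steps are run
```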
@@ -213,15 +263,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
 
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -244,10 +290,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            t_index = t_start + i
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
+                sigma = self.scheduler.sigmas[t_index]
                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
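The `t_index` introduced here (and used again in the next hunk) compensates for the sliced timestep list: once the loop iterates only over `timesteps[t_start:]`, the `enumerate()` index no longer matches the scheduler's own indexing into `sigmas` and `timesteps`. A tiny illustration of that offset, with a plain list standing in for the scheduler's timesteps:

```python
# Why t_index is needed: the loop only sees the tail of the schedule, so the loop
# index must be shifted by t_start to line up with the scheduler's full arrays.
timesteps = list(range(50))        # stand-in for self.scheduler.timesteps
t_start = 10
for i, t in enumerate(timesteps[t_start:]):
    t_index = t_start + i          # index into scheduler.sigmas / scheduler.timesteps
    assert timesteps[t_index] == t
```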
@@ -270,10 +318,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[i+1]
+                if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
+                    t_prev = self.scheduler.timesteps[t_index+1]
                     latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
             else:
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
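For reference, the new `preprocess` helper converts a PIL image into the NCHW, [-1, 1] float tensor that `self.vae.encode` expects. A quick standalone check, copying the function from the diff and feeding it a random placeholder image:

```python
# Standalone check of the new preprocess() helper: PIL image -> NCHW float tensor in [-1, 1].
import numpy as np
import PIL.Image
import torch


def preprocess(image, w, h):
    # copied from the diff above: resize, scale to [0, 1], HWC -> NCHW, then map to [-1, 1]
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


# Random RGB stand-in for a real init image (any PIL RGB image behaves the same way).
img = PIL.Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))

x = preprocess(img, 512, 512)
print(x.shape)                         # torch.Size([1, 3, 512, 512])
print(x.min().item(), x.max().item())  # both within [-1.0, 1.0]
```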