author    Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
committer Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
commit    49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch)
tree      3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /pipelines
parent    Fix seed, better progress bar, fix euler_a for batch size > 1 (diff)
WIP: img2img
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py (renamed from pipelines/stable_diffusion/clip_guided_stable_diffusion.py) | 80
1 file changed, 64 insertions(+), 16 deletions(-)
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index eff74b5..4c793a8 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -2,22 +2,29 @@ import inspect
 import warnings
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
+import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
-class CLIPGuidedStableDiffusion(DiffusionPipeline):
+def preprocess(image, w, h):
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
         vae: AutoencoderKL,
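
The new preprocess() helper maps a PIL image onto the [-1, 1], NCHW tensor layout the VAE encoder expects. A minimal sanity check of that contract, assuming the module path shown in the diffstat above (the gray test image and the printed values are illustrative only):

import PIL.Image
from pipelines.stable_diffusion.vlpn_stable_diffusion import preprocess

img = PIL.Image.new("RGB", (1000, 600), color=(128, 128, 128))  # placeholder input image
t = preprocess(img, 768, 512)            # resize to 768x512, rescale to [-1, 1]
print(t.shape)                           # torch.Size([1, 3, 512, 768])
print(t.min().item(), t.max().item())    # both ~0.004 for a uniform mid-gray input
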
@@ -83,13 +90,14 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
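
With this signature the existing latents argument is overloaded: passing a tensor keeps the old txt2img behaviour, while passing a PIL.Image.Image switches to the new img2img path controlled by strength. A hypothetical call sketch (the pipe instance and the file name are placeholders, not part of this commit):

import PIL.Image

# `pipe` is assumed to be an already constructed VlpnStableDiffusion instance.
init_image = PIL.Image.open("sketch.png").convert("RGB")  # placeholder path

result = pipe(
    prompt="a watercolor painting of a fox",
    latents=init_image,        # PIL image -> img2img branch
    strength=0.6,              # run roughly 60% of the denoising steps
    num_inference_steps=50,
    guidance_scale=7.5,
)
image = result.images[0]       # StableDiffusionPipelineOutput exposes .images
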
@@ -99,6 +107,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
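
A worked example of the strength semantics documented above, using the arithmetic this patch introduces further down (steps_offset=1 is an assumption; PNDM-style scheduler configs commonly use it):

num_inference_steps, strength, offset = 50, 0.8, 1

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 41
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10

# The denoising loop then iterates over scheduler.timesteps[t_start:], i.e. 40 of the
# 50 configured steps; strength=1.0 would run all of them and effectively ignore init_image.
print(init_timestep, t_start)
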
@@ -158,6 +172,42 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        offset = self.scheduler.config.get("steps_offset", 0)
+
+        if latents is not None and isinstance(latents, PIL.Image.Image):
+            latents = preprocess(latents, width, height)
+            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latent_dist.sample(generator=generator)
+            latents = 0.18215 * latents
+            latents = torch.cat([latents] * batch_size)
+
+            # get the original timestep using init_timestep
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            elif isinstance(self.scheduler, EulerAScheduler):
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        else:
+            init_timestep = num_inference_steps + offset
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
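
The 0.18215 factor is the standard Stable Diffusion latent scaling; after encoding, the latents are pushed part-way along the forward diffusion process so that denoising can resume from there. A toy sketch of what add_noise amounts to for DDPM/DDIM-style schedulers (all values below are placeholders; sigma-based schedulers such as Euler-A mix differently):

import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)   # placeholder noise schedule
latents = torch.randn(1, 4, 64, 64)                    # stand-in for the VAE-encoded image
noise = torch.randn_like(latents)

def add_noise(latents, noise, t):
    a = alphas_cumprod[t]
    return a.sqrt() * latents + (1 - a).sqrt() * noise

weak = add_noise(latents, noise, torch.tensor(100))    # low strength: mostly the image survives
strong = add_noise(latents, noise, torch.tensor(900))  # high strength: mostly fresh noise
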
@@ -213,15 +263,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
 
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
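
Since set_timesteps now runs before latent preparation, the loop can start mid-schedule: slicing timesteps[t_start:] drops the earliest, noisiest steps that the pre-noised init latents already account for. A small illustration with a made-up 50-step schedule (the real values depend on the scheduler):

import torch

timesteps = torch.arange(981, -1, -20)   # toy schedule: 981, 961, ..., 1 (50 entries)
t_start = 10

print(timesteps.shape[0])                # 50 configured steps
print(timesteps[t_start:].shape[0])      # 40 steps actually executed (strength=0.8, offset=1)
print(timesteps[t_start].item())         # 781, the noisiest timestep the loop will see
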
@@ -244,10 +290,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            t_index = t_start + i
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
+                sigma = self.scheduler.sigmas[t_index]
                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
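
Because the loop now enumerates only the tail of the schedule, the loop counter i restarts at 0 while self.scheduler.sigmas stays aligned with the full timestep array; t_index = t_start + i recovers the absolute position. A toy illustration of the off-by-t_start error this avoids (values are made up):

import torch

sigmas = torch.linspace(14.6, 0.0, 50)   # toy sigma table, one entry per configured step
t_start, i = 10, 0                       # first iteration after skipping the first 10 steps

print(sigmas[i].item())                  # 14.6  -> sigma of a step that was never run
print(sigmas[t_start + i].item())        # ~11.6 -> the sigma where denoising actually resumes
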
@@ -270,10 +318,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[i+1]
+                if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
+                    t_prev = self.scheduler.timesteps[t_index+1]
                     latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
                 else:
                     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
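
For context, a hypothetical end-to-end img2img invocation of the renamed pipeline. The constructor keywords are inferred from the imports above, component loading mirrors the upstream StableDiffusionPipeline, and every model id, path, and prompt is a placeholder rather than part of this commit:

import PIL.Image
import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

model = "CompVis/stable-diffusion-v1-4"  # placeholder model id
pipe = VlpnStableDiffusion(
    vae=AutoencoderKL.from_pretrained(model, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(model, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(model, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(model, subfolder="unet"),
    scheduler=DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        clip_sample=False, set_alpha_to_one=False,
    ),
).to("cuda")

init_image = PIL.Image.open("input.png").convert("RGB")   # placeholder input
with torch.autocast("cuda"):
    result = pipe(
        prompt="a watercolor painting of a fox",           # placeholder prompt
        latents=init_image,
        strength=0.75,
        num_inference_steps=50,
    )
result.images[0].save("output.png")
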