path: root/pipelines/stable_diffusion/vlpn_stable_diffusion.py
author    Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
committer Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
commit    49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch)
tree      3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /pipelines/stable_diffusion/vlpn_stable_diffusion.py
parent    Fix seed, better progress bar, fix euler_a for batch size > 1 (diff)
WIP: img2img
Diffstat (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  342
1 file changed, 342 insertions(+), 0 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
new file mode 100644
index 0000000..4c793a8
--- /dev/null
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -0,0 +1,342 @@
1import inspect
2import warnings
3from typing import List, Optional, Union
4
5import numpy as np
6import torch
7import PIL
8
9from diffusers.configuration_utils import FrozenDict
10from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
12from diffusers.utils import logging
13from transformers import CLIPTextModel, CLIPTokenizer
14from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
15
16logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
18
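# Resizes a PIL image to (w, h), converts it to a 1xCxHxW float tensor, and maps
# pixel values from [0, 255] to [-1, 1], the input range the VAE encoder expects.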
19def preprocess(image, w, h):
20 image = image.resize((w, h), resample=PIL.Image.LANCZOS)
21 image = np.array(image).astype(np.float32) / 255.0
22 image = image[None].transpose(0, 3, 1, 2)
23 image = torch.from_numpy(image)
24 return 2.0 * image - 1.0
25
26
27class VlpnStableDiffusion(DiffusionPipeline):
28 def __init__(
29 self,
30 vae: AutoencoderKL,
31 text_encoder: CLIPTextModel,
32 tokenizer: CLIPTokenizer,
33 unet: UNet2DConditionModel,
34 scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
35 **kwargs,
36 ):
37 super().__init__()
38
39 if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
40 warnings.warn(
41 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
42 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
43                 "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
44 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
45 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
46 " file",
47 DeprecationWarning,
48 )
49 new_config = dict(scheduler.config)
50 new_config["steps_offset"] = 1
51 scheduler._internal_dict = FrozenDict(new_config)
52
53 self.register_modules(
54 vae=vae,
55 text_encoder=text_encoder,
56 tokenizer=tokenizer,
57 unet=unet,
58 scheduler=scheduler,
59 )
60
61 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
62 r"""
63 Enable sliced attention computation.
64
65 When this option is enabled, the attention module will split the input tensor in slices, to compute attention
66 in several steps. This is useful to save some memory in exchange for a small speed decrease.
67
68 Args:
69 slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
70 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
71 a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
72 `attention_head_dim` must be a multiple of `slice_size`.
73 """
74 if slice_size == "auto":
75 # half the attention head size is usually a good trade-off between
76 # speed and memory
77 slice_size = self.unet.config.attention_head_dim // 2
78 self.unet.set_attention_slice(slice_size)
79
80 def disable_attention_slicing(self):
81 r"""
82 Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
83 back to computing attention in one step.
84 """
85 # set slice_size = `None` to disable `attention slicing`
86 self.enable_attention_slicing(None)
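    # Usage sketch (illustrative; `pipe` stands for an instance of this pipeline):
    #   pipe.enable_attention_slicing()    # "auto": slice to attention_head_dim // 2
    #   pipe.enable_attention_slicing(1)   # most slices, lowest peak memory, slowest
    #   pipe.disable_attention_slicing()   # compute attention in a single step again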
87
88 @torch.no_grad()
89 def __call__(
90 self,
91 prompt: Union[str, List[str]],
92 negative_prompt: Optional[Union[str, List[str]]] = None,
93 strength: float = 0.8,
94 height: Optional[int] = 512,
95 width: Optional[int] = 512,
96 num_inference_steps: Optional[int] = 50,
97 guidance_scale: Optional[float] = 7.5,
98 eta: Optional[float] = 0.0,
99 generator: Optional[torch.Generator] = None,
100 latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
101 output_type: Optional[str] = "pil",
102 return_dict: bool = True,
103 ):
104 r"""
105 Function invoked when calling the pipeline for generation.
106
107 Args:
108             prompt (`str` or `List[str]`):
109                 The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance
                (i.e. if `guidance_scale` is less than `1`).
110 strength (`float`, *optional*, defaults to 0.8):
111                 Conceptually, indicates how much to transform the init image passed via `latents`. Must be
112                 between 0 and 1. The init image will be used as a starting point, adding more noise to it the
113                 larger the `strength`. The number of denoising steps depends on the amount of noise initially
114                 added. When `strength` is 1, added noise will be maximum and the denoising process will run for
115                 the full number of iterations specified in `num_inference_steps`, essentially ignoring the init image.
116 height (`int`, *optional*, defaults to 512):
117 The height in pixels of the generated image.
118 width (`int`, *optional*, defaults to 512):
119 The width in pixels of the generated image.
120 num_inference_steps (`int`, *optional*, defaults to 50):
121 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
122 expense of slower inference.
123 guidance_scale (`float`, *optional*, defaults to 7.5):
124 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
125 `guidance_scale` is defined as `w` of equation 2. of [Imagen
126 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
127                 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
128 usually at the expense of lower image quality.
129 eta (`float`, *optional*, defaults to 0.0):
130 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
131 [`schedulers.DDIMScheduler`], will be ignored for others.
132 generator (`torch.Generator`, *optional*):
133 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
134 deterministic.
135             latents (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
136                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
137                 generation, or an init image to be encoded by the VAE and noised according to `strength` (img2img).
138                 Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`.
139 output_type (`str`, *optional*, defaults to `"pil"`):
140                 The output format of the generated image. Choose between
141 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
142 return_dict (`bool`, *optional*, defaults to `True`):
143 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
144 plain tuple.
145
146 Returns:
147 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
148             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
149             When returning a tuple, the first element is a list with the generated images, and the second element is
150             `None`, since this pipeline does not run a safety checker.
152 """
153
154 if isinstance(prompt, str):
155 batch_size = 1
156 elif isinstance(prompt, list):
157 batch_size = len(prompt)
158 else:
159 raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
160
161 if negative_prompt is None:
162 negative_prompt = [""] * batch_size
163 elif isinstance(negative_prompt, str):
164 negative_prompt = [negative_prompt] * batch_size
165 elif isinstance(negative_prompt, list):
166 if len(negative_prompt) != batch_size:
167 raise ValueError(
168 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
169 else:
170 raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
171
172 if height % 8 != 0 or width % 8 != 0:
173 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
174
175 if strength < 0 or strength > 1:
176             raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}")
177
178 # set timesteps
179 self.scheduler.set_timesteps(num_inference_steps)
180
181 offset = self.scheduler.config.get("steps_offset", 0)
182
183 if latents is not None and isinstance(latents, PIL.Image.Image):
184 latents = preprocess(latents, width, height)
185 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
186 latents = latent_dist.sample(generator=generator)
187 latents = 0.18215 * latents
188 latents = torch.cat([latents] * batch_size)
189
190 # get the original timestep using init_timestep
191 init_timestep = int(num_inference_steps * strength) + offset
192 init_timestep = min(init_timestep, num_inference_steps)
193
194 if isinstance(self.scheduler, LMSDiscreteScheduler):
195 timesteps = torch.tensor(
196 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
197 )
198 elif isinstance(self.scheduler, EulerAScheduler):
199 timesteps = self.scheduler.timesteps[-init_timestep]
200 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
201 else:
202 timesteps = self.scheduler.timesteps[-init_timestep]
203 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
204
205 # add noise to latents using the timesteps
206 noise = torch.randn(latents.shape, generator=generator, device=self.device)
207 latents = self.scheduler.add_noise(latents, noise, timesteps)
208 else:
209 init_timestep = num_inference_steps + offset
210
211 # get prompt text embeddings
212 text_inputs = self.tokenizer(
213 prompt,
214 padding="max_length",
215 max_length=self.tokenizer.model_max_length,
216 return_tensors="pt",
217 )
218 text_input_ids = text_inputs.input_ids
219
220 if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
221 removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
222 logger.warning(
223 "The following part of your input was truncated because CLIP can only handle sequences up to"
224 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
225 )
226 print(f"Too many tokens: {removed_text}")
227 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
228 text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
229
230 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
231 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
232 # corresponds to doing no classifier free guidance.
233 do_classifier_free_guidance = guidance_scale > 1.0
234 # get unconditional embeddings for classifier free guidance
235 if do_classifier_free_guidance:
236 max_length = text_input_ids.shape[-1]
237 uncond_input = self.tokenizer(
238 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
239 )
240 uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
241
242 # For classifier free guidance, we need to do two forward passes.
243 # Here we concatenate the unconditional and text embeddings into a single batch
244 # to avoid doing two forward passes
245 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
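            # The order [uncond, text] matters: the `noise_pred.chunk(2)` call in the denoising
            # loop below relies on the first half of the doubled batch being the unconditional
            # prediction and the second half the text-conditioned one.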
246
247 # get the initial random noise unless the user supplied it
248
249 # Unlike in other pipelines, latents need to be generated in the target device
250 # for 1-to-1 results reproducibility with the CompVis implementation.
251 # However this currently doesn't work in `mps`.
252 latents_device = "cpu" if self.device.type == "mps" else self.device
253 latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
254 if latents is None:
255 latents = torch.randn(
256 latents_shape,
257 generator=generator,
258 device=latents_device,
259 dtype=text_embeddings.dtype,
260 )
261 else:
262 if latents.shape != latents_shape:
263 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
264 latents = latents.to(self.device)
265
266 t_start = max(num_inference_steps - init_timestep + offset, 0)
267
268 # Some schedulers like PNDM have timesteps as arrays
269         # It's more optimized to move all timesteps to the correct device beforehand
270 timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
271
272 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
273 if isinstance(self.scheduler, LMSDiscreteScheduler):
274 latents = latents * self.scheduler.sigmas[0]
275 elif isinstance(self.scheduler, EulerAScheduler):
276 sigma = self.scheduler.timesteps[0]
277 latents = latents * sigma
278
279 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
280 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
281 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
282 # and should be between [0, 1]
283 scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
284 accepts_eta = "eta" in scheduler_step_args
285 extra_step_kwargs = {}
286 if accepts_eta:
287 extra_step_kwargs["eta"] = eta
288 accepts_generator = "generator" in scheduler_step_args
289 if generator is not None and accepts_generator:
290 extra_step_kwargs["generator"] = generator
291
292 for i, t in enumerate(self.progress_bar(timesteps_tensor)):
293 t_index = t_start + i
294
295 # expand the latents if we are doing classifier free guidance
296 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
297 if isinstance(self.scheduler, LMSDiscreteScheduler):
298 sigma = self.scheduler.sigmas[t_index]
299 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
300 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
301
302 noise_pred = None
303 if isinstance(self.scheduler, EulerAScheduler):
304 sigma = t.reshape(1)
305 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
306 # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
307 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
308 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
309 # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
310 else:
311 # predict the noise residual
312 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
313
314 # perform guidance
315 if do_classifier_free_guidance:
316 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
317 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
318
319 # compute the previous noisy sample x_t -> x_t-1
320 if isinstance(self.scheduler, LMSDiscreteScheduler):
321 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
322 elif isinstance(self.scheduler, EulerAScheduler):
323                 if t_index < self.scheduler.timesteps.shape[0] - 1:  # avoid an out-of-bounds index
324 t_prev = self.scheduler.timesteps[t_index+1]
325 latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
326 else:
327 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
328
329 # scale and decode the image latents with vae
330 latents = 1 / 0.18215 * latents
331 image = self.vae.decode(latents).sample
332
333 image = (image / 2 + 0.5).clamp(0, 1)
334 image = image.cpu().permute(0, 2, 3, 1).numpy()
335
336 if output_type == "pil":
337 image = self.numpy_to_pil(image)
338
339 if not return_dict:
340 return (image, None)
341
342 return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
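
For context, a minimal usage sketch of the pipeline above. This is a hedged illustration, not part of the commit: the model id, file names and scheduler settings are assumptions (standard Stable Diffusion v1 conventions), and the img2img path is still marked WIP.

import torch
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

model_id = "CompVis/stable-diffusion-v1-4"  # assumption: any SD v1 checkpoint with this layout

pipe = VlpnStableDiffusion(
    vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
    scheduler=DDIMScheduler(  # typical SD v1 DDIM settings, including steps_offset=1
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        clip_sample=False, set_alpha_to_one=False, steps_offset=1,
    ),
).to("cuda")
pipe.enable_attention_slicing()

generator = torch.Generator(device="cuda").manual_seed(42)

# Text-to-image: `latents` stays None, so random noise is sampled internally.
image = pipe(
    "a watercolor painting of a fox",
    negative_prompt="blurry",
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=generator,
).images[0]
image.save("txt2img.png")

# Image-to-image (the WIP path): pass a PIL image as `latents` together with a `strength`.
init_image = Image.open("input.png").convert("RGB")  # assumption: any RGB input image
image = pipe(
    "a watercolor painting of a fox",
    latents=init_image,
    strength=0.6,
    generator=generator,
).images[0]
image.save("img2img.png")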