author     Volpeon <git@volpeon.ink>    2022-09-30 14:13:51 +0200
committer  Volpeon <git@volpeon.ink>    2022-09-30 14:13:51 +0200
commit     9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch)
tree       ad186862f5095663966dd1d42455023080aa0c4e /pipelines
parent     Better sample file structure (diff)
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/clip_guided_stable_diffusion.py  457
1 file changed, 457 insertions, 0 deletions
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
new file mode 100644
index 0000000..306d9a9
--- /dev/null
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -0,0 +1,457 @@
import inspect
import warnings
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import FrozenDict
from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


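# MakeCutouts takes `num_cutouts` random square crops of the decoded image (sizes drawn
# according to `cut_power`) and pools them to the CLIP input resolution; scoring CLIP on
# several cutouts is the usual trick from CLIP-guided diffusion to make the guidance
# signal less sensitive to overall composition.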
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()

        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        sideY, sideX = pixel_values.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(num_cutouts):
            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


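# spherical_dist_loss is proportional to the squared great-circle distance between the
# L2-normalized inputs (it evaluates to half the squared angle between them); it is used
# in cond_fn to compare CLIP image and text embeddings.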
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.dev/crowsonkb/k-diffusion
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        feature_extractor: CLIPFeatureExtractor,
        **kwargs,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        self.make_cutouts = MakeCutouts(feature_extractor.size)

        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def freeze_vae(self):
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        set_requires_grad(self.unet, True)

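    # cond_fn runs with gradients re-enabled (the surrounding __call__ is wrapped in
    # torch.no_grad) so the CLIP loss can be backpropagated through the UNet and the
    # VAE decoder down to the noisy latents.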
    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        latents = latents.detach().requires_grad_()

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
        else:
            latent_model_input = latents

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, PNDMScheduler):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / 0.18215 * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            image = transforms.Resize(self.feature_extractor.size)(image)
        image = self.normalize(image)

        image_embeddings_clip = self.clip_model.get_image_features(image).float()
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
            dists = dists.view([num_cutouts, sample.shape[0], -1])
            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

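        # gradient of the CLIP loss w.r.t. the noisy latents; the negative gradient is
        # either added to the latents scaled by sigma**2 (K-LMS) or folded back into the
        # noise prediction (PNDM) below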
        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
        return noise_pred, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
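            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Only used when classifier-free
                guidance is active (`guidance_scale > 1`).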
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
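            clip_guidance_scale (`float`, *optional*, defaults to 100):
                Strength of the CLIP guidance applied in `cond_fn`. Set to 0 to disable CLIP guidance.
            clip_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts used to compute the CLIP text embedding for guidance. Defaults to
                `prompt` if not provided.
            num_cutouts (`int`, *optional*, defaults to 4):
                Number of random cutouts taken from the decoded image before scoring it with CLIP. Only
                used when `use_cutouts` is `True`.
            use_cutouts (`bool`, *optional*, defaults to `True`):
                Whether to score CLIP on random cutouts of the decoded image or on a single resized copy.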
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        elif isinstance(negative_prompt, list):
            if len(negative_prompt) != batch_size:
                raise ValueError(
                    f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
        else:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_inputs = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                clip_text_input_ids = clip_text_inputs.input_ids
            else:
                clip_text_input_ids = text_input_ids
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_device = "cpu" if self.device.type == "mps" else self.device
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
        if latents is None:
            latents = torch.randn(
                latents_shape,
                generator=generator,
                device=latents_device,
                dtype=text_embeddings.dtype,
            )
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to the correct device beforehand
        if torch.is_tensor(self.scheduler.timesteps):
            timesteps_tensor = self.scheduler.timesteps.to(self.device)
        else:
            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]
        elif isinstance(self.scheduler, EulerAScheduler):
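            # for EulerAScheduler, `timesteps` holds the sigma schedule at this point,
            # so the initial noise is scaled by the first sigma, mirroring the K-LMS
            # branch above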
            sigma = self.scheduler.timesteps[0]
            latents = latents * sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
        accepts_eta = "eta" in scheduler_step_args
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        accepts_generator = "generator" in scheduler_step_args
        if generator is not None and accepts_generator:
            extra_step_kwargs["generator"] = generator

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            noise_pred = None
            if isinstance(self.scheduler, EulerAScheduler):
                sigma = t.reshape(1)
                sigma_in = torch.cat([sigma] * 2)
                # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
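                # CFGDenoiserForward runs the UNet on the doubled (uncond + cond) batch and
                # applies classifier-free guidance internally, k-diffusion style, so this
                # branch needs no separate guidance step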
                noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
                                                text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas)
                # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
            else:
                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # perform clip guidance
            if clip_guidance_scale > 0:
                text_embeddings_for_guidance = (
                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                )
                noise_pred, latents = self.cond_fn(
                    latents,
                    t,
                    i,
                    text_embeddings_for_guidance,
                    noise_pred,
                    text_embeddings_clip,
                    clip_guidance_scale,
                    num_cutouts,
                    use_cutouts,
                )

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
            elif isinstance(self.scheduler, EulerAScheduler):
                if i < self.scheduler.timesteps.shape[0] - 1:  # avoid out of bound error
                    t_prev = self.scheduler.timesteps[i + 1]
                    latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
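
For reference, a minimal usage sketch of the new pipeline (untested; the checkpoint id, CLIP model id and scheduler settings below are illustrative assumptions, not part of this commit):

    import torch
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
    from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer

    from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion

    model_id = "CompVis/stable-diffusion-v1-4"   # assumed SD checkpoint
    clip_id = "openai/clip-vit-large-patch14"    # assumed CLIP model for guidance

    pipe = CLIPGuidedStableDiffusion(
        vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(clip_id),
        clip_model=CLIPModel.from_pretrained(clip_id),
        tokenizer=CLIPTokenizer.from_pretrained(clip_id),
        unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
        scheduler=LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"),
        feature_extractor=CLIPFeatureExtractor.from_pretrained(clip_id),
    ).to("cuda")

    image = pipe(
        prompt="a watercolor painting of a lighthouse at dawn",
        num_inference_steps=50,
        guidance_scale=7.5,
        clip_guidance_scale=100,
        num_cutouts=4,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]
    image.save("clip_guided_sample.png")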