diff options
| author | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
| commit | 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch) | |
| tree | ad186862f5095663966dd1d42455023080aa0c4e /pipelines | |
| parent | Better sample file structure (diff) | |
| download | textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2 textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip | |
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'pipelines')
| -rw-r--r-- | pipelines/stable_diffusion/clip_guided_stable_diffusion.py | 457 |
1 files changed, 457 insertions, 0 deletions
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py new file mode 100644 index 0000000..306d9a9 --- /dev/null +++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py | |||
| @@ -0,0 +1,457 @@ | |||
| 1 | import inspect | ||
| 2 | import warnings | ||
| 3 | from typing import List, Optional, Union | ||
| 4 | |||
| 5 | import torch | ||
| 6 | from torch import nn | ||
| 7 | from torch.nn import functional as F | ||
| 8 | |||
| 9 | from diffusers.configuration_utils import FrozenDict | ||
| 10 | from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel | ||
| 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | ||
| 12 | from diffusers.utils import logging | ||
| 13 | from torchvision import transforms | ||
| 14 | from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer | ||
| 15 | from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward | ||
| 16 | |||
| 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
| 18 | |||
| 19 | |||
| 20 | class MakeCutouts(nn.Module): | ||
| 21 | def __init__(self, cut_size, cut_power=1.0): | ||
| 22 | super().__init__() | ||
| 23 | |||
| 24 | self.cut_size = cut_size | ||
| 25 | self.cut_power = cut_power | ||
| 26 | |||
| 27 | def forward(self, pixel_values, num_cutouts): | ||
| 28 | sideY, sideX = pixel_values.shape[2:4] | ||
| 29 | max_size = min(sideX, sideY) | ||
| 30 | min_size = min(sideX, sideY, self.cut_size) | ||
| 31 | cutouts = [] | ||
| 32 | for _ in range(num_cutouts): | ||
| 33 | size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) | ||
| 34 | offsetx = torch.randint(0, sideX - size + 1, ()) | ||
| 35 | offsety = torch.randint(0, sideY - size + 1, ()) | ||
| 36 | cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size] | ||
| 37 | cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) | ||
| 38 | return torch.cat(cutouts) | ||
| 39 | |||
| 40 | |||
| 41 | def spherical_dist_loss(x, y): | ||
| 42 | x = F.normalize(x, dim=-1) | ||
| 43 | y = F.normalize(y, dim=-1) | ||
| 44 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) | ||
| 45 | |||
| 46 | |||
| 47 | def set_requires_grad(model, value): | ||
| 48 | for param in model.parameters(): | ||
| 49 | param.requires_grad = value | ||
| 50 | |||
| 51 | |||
| 52 | class CLIPGuidedStableDiffusion(DiffusionPipeline): | ||
| 53 | """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000 | ||
| 54 | - https://github.com/Jack000/glid-3-xl | ||
| 55 | - https://github.dev/crowsonkb/k-diffusion | ||
| 56 | """ | ||
| 57 | |||
| 58 | def __init__( | ||
| 59 | self, | ||
| 60 | vae: AutoencoderKL, | ||
| 61 | text_encoder: CLIPTextModel, | ||
| 62 | clip_model: CLIPModel, | ||
| 63 | tokenizer: CLIPTokenizer, | ||
| 64 | unet: UNet2DConditionModel, | ||
| 65 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], | ||
| 66 | feature_extractor: CLIPFeatureExtractor, | ||
| 67 | **kwargs, | ||
| 68 | ): | ||
| 69 | super().__init__() | ||
| 70 | |||
| 71 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | ||
| 72 | warnings.warn( | ||
| 73 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | ||
| 74 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | ||
| 75 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" | ||
| 76 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," | ||
| 77 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" | ||
| 78 | " file", | ||
| 79 | DeprecationWarning, | ||
| 80 | ) | ||
| 81 | new_config = dict(scheduler.config) | ||
| 82 | new_config["steps_offset"] = 1 | ||
| 83 | scheduler._internal_dict = FrozenDict(new_config) | ||
| 84 | |||
| 85 | self.register_modules( | ||
| 86 | vae=vae, | ||
| 87 | text_encoder=text_encoder, | ||
| 88 | clip_model=clip_model, | ||
| 89 | tokenizer=tokenizer, | ||
| 90 | unet=unet, | ||
| 91 | scheduler=scheduler, | ||
| 92 | feature_extractor=feature_extractor, | ||
| 93 | ) | ||
| 94 | |||
| 95 | self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) | ||
| 96 | self.make_cutouts = MakeCutouts(feature_extractor.size) | ||
| 97 | |||
| 98 | set_requires_grad(self.text_encoder, False) | ||
| 99 | set_requires_grad(self.clip_model, False) | ||
| 100 | |||
| 101 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | ||
| 102 | r""" | ||
| 103 | Enable sliced attention computation. | ||
| 104 | |||
| 105 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention | ||
| 106 | in several steps. This is useful to save some memory in exchange for a small speed decrease. | ||
| 107 | |||
| 108 | Args: | ||
| 109 | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): | ||
| 110 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If | ||
| 111 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, | ||
| 112 | `attention_head_dim` must be a multiple of `slice_size`. | ||
| 113 | """ | ||
| 114 | if slice_size == "auto": | ||
| 115 | # half the attention head size is usually a good trade-off between | ||
| 116 | # speed and memory | ||
| 117 | slice_size = self.unet.config.attention_head_dim // 2 | ||
| 118 | self.unet.set_attention_slice(slice_size) | ||
| 119 | |||
| 120 | def disable_attention_slicing(self): | ||
| 121 | r""" | ||
| 122 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go | ||
| 123 | back to computing attention in one step. | ||
| 124 | """ | ||
| 125 | # set slice_size = `None` to disable `attention slicing` | ||
| 126 | self.enable_attention_slicing(None) | ||
| 127 | |||
| 128 | def freeze_vae(self): | ||
| 129 | set_requires_grad(self.vae, False) | ||
| 130 | |||
| 131 | def unfreeze_vae(self): | ||
| 132 | set_requires_grad(self.vae, True) | ||
| 133 | |||
| 134 | def freeze_unet(self): | ||
| 135 | set_requires_grad(self.unet, False) | ||
| 136 | |||
| 137 | def unfreeze_unet(self): | ||
| 138 | set_requires_grad(self.unet, True) | ||
| 139 | |||
| 140 | @torch.enable_grad() | ||
| 141 | def cond_fn( | ||
| 142 | self, | ||
| 143 | latents, | ||
| 144 | timestep, | ||
| 145 | index, | ||
| 146 | text_embeddings, | ||
| 147 | noise_pred_original, | ||
| 148 | text_embeddings_clip, | ||
| 149 | clip_guidance_scale, | ||
| 150 | num_cutouts, | ||
| 151 | use_cutouts=True, | ||
| 152 | ): | ||
| 153 | latents = latents.detach().requires_grad_() | ||
| 154 | |||
| 155 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 156 | sigma = self.scheduler.sigmas[index] | ||
| 157 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | ||
| 158 | latent_model_input = latents / ((sigma**2 + 1) ** 0.5) | ||
| 159 | else: | ||
| 160 | latent_model_input = latents | ||
| 161 | |||
| 162 | # predict the noise residual | ||
| 163 | noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample | ||
| 164 | |||
| 165 | if isinstance(self.scheduler, PNDMScheduler): | ||
| 166 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] | ||
| 167 | beta_prod_t = 1 - alpha_prod_t | ||
| 168 | # compute predicted original sample from predicted noise also called | ||
| 169 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||
| 170 | pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) | ||
| 171 | |||
| 172 | fac = torch.sqrt(beta_prod_t) | ||
| 173 | sample = pred_original_sample * (fac) + latents * (1 - fac) | ||
| 174 | elif isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 175 | sigma = self.scheduler.sigmas[index] | ||
| 176 | sample = latents - sigma * noise_pred | ||
| 177 | else: | ||
| 178 | raise ValueError(f"scheduler type {type(self.scheduler)} not supported") | ||
| 179 | |||
| 180 | sample = 1 / 0.18215 * sample | ||
| 181 | image = self.vae.decode(sample).sample | ||
| 182 | image = (image / 2 + 0.5).clamp(0, 1) | ||
| 183 | |||
| 184 | if use_cutouts: | ||
| 185 | image = self.make_cutouts(image, num_cutouts) | ||
| 186 | else: | ||
| 187 | image = transforms.Resize(self.feature_extractor.size)(image) | ||
| 188 | image = self.normalize(image) | ||
| 189 | |||
| 190 | image_embeddings_clip = self.clip_model.get_image_features(image).float() | ||
| 191 | image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) | ||
| 192 | |||
| 193 | if use_cutouts: | ||
| 194 | dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip) | ||
| 195 | dists = dists.view([num_cutouts, sample.shape[0], -1]) | ||
| 196 | loss = dists.sum(2).mean(0).sum() * clip_guidance_scale | ||
| 197 | else: | ||
| 198 | loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale | ||
| 199 | |||
| 200 | grads = -torch.autograd.grad(loss, latents)[0] | ||
| 201 | |||
| 202 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 203 | latents = latents.detach() + grads * (sigma**2) | ||
| 204 | noise_pred = noise_pred_original | ||
| 205 | else: | ||
| 206 | noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads | ||
| 207 | return noise_pred, latents | ||
| 208 | |||
| 209 | @torch.no_grad() | ||
| 210 | def __call__( | ||
| 211 | self, | ||
| 212 | prompt: Union[str, List[str]], | ||
| 213 | negative_prompt: Optional[Union[str, List[str]]] = None, | ||
| 214 | height: Optional[int] = 512, | ||
| 215 | width: Optional[int] = 512, | ||
| 216 | num_inference_steps: Optional[int] = 50, | ||
| 217 | guidance_scale: Optional[float] = 7.5, | ||
| 218 | eta: Optional[float] = 0.0, | ||
| 219 | clip_guidance_scale: Optional[float] = 100, | ||
| 220 | clip_prompt: Optional[Union[str, List[str]]] = None, | ||
| 221 | num_cutouts: Optional[int] = 4, | ||
| 222 | use_cutouts: Optional[bool] = True, | ||
| 223 | generator: Optional[torch.Generator] = None, | ||
| 224 | latents: Optional[torch.FloatTensor] = None, | ||
| 225 | output_type: Optional[str] = "pil", | ||
| 226 | return_dict: bool = True, | ||
| 227 | ): | ||
| 228 | r""" | ||
| 229 | Function invoked when calling the pipeline for generation. | ||
| 230 | |||
| 231 | Args: | ||
| 232 | prompt (`str` or `List[str]`): | ||
| 233 | The prompt or prompts to guide the image generation. | ||
| 234 | height (`int`, *optional*, defaults to 512): | ||
| 235 | The height in pixels of the generated image. | ||
| 236 | width (`int`, *optional*, defaults to 512): | ||
| 237 | The width in pixels of the generated image. | ||
| 238 | num_inference_steps (`int`, *optional*, defaults to 50): | ||
| 239 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||
| 240 | expense of slower inference. | ||
| 241 | guidance_scale (`float`, *optional*, defaults to 7.5): | ||
| 242 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | ||
| 243 | `guidance_scale` is defined as `w` of equation 2. of [Imagen | ||
| 244 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | ||
| 245 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | ||
| 246 | usually at the expense of lower image quality. | ||
| 247 | eta (`float`, *optional*, defaults to 0.0): | ||
| 248 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | ||
| 249 | [`schedulers.DDIMScheduler`], will be ignored for others. | ||
| 250 | generator (`torch.Generator`, *optional*): | ||
| 251 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation | ||
| 252 | deterministic. | ||
| 253 | latents (`torch.FloatTensor`, *optional*): | ||
| 254 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | ||
| 255 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | ||
| 256 | tensor will ge generated by sampling using the supplied random `generator`. | ||
| 257 | output_type (`str`, *optional*, defaults to `"pil"`): | ||
| 258 | The output format of the generate image. Choose between | ||
| 259 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||
| 260 | return_dict (`bool`, *optional*, defaults to `True`): | ||
| 261 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | ||
| 262 | plain tuple. | ||
| 263 | |||
| 264 | Returns: | ||
| 265 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | ||
| 266 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. | ||
| 267 | When returning a tuple, the first element is a list with the generated images, and the second element is a | ||
| 268 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | ||
| 269 | (nsfw) content, according to the `safety_checker`. | ||
| 270 | """ | ||
| 271 | |||
| 272 | if isinstance(prompt, str): | ||
| 273 | batch_size = 1 | ||
| 274 | elif isinstance(prompt, list): | ||
| 275 | batch_size = len(prompt) | ||
| 276 | else: | ||
| 277 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
| 278 | |||
| 279 | if negative_prompt is None: | ||
| 280 | negative_prompt = [""] * batch_size | ||
| 281 | elif isinstance(negative_prompt, str): | ||
| 282 | negative_prompt = [negative_prompt] * batch_size | ||
| 283 | elif isinstance(negative_prompt, list): | ||
| 284 | if len(negative_prompt) != batch_size: | ||
| 285 | raise ValueError( | ||
| 286 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
| 287 | else: | ||
| 288 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | ||
| 289 | |||
| 290 | if height % 8 != 0 or width % 8 != 0: | ||
| 291 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
| 292 | |||
| 293 | # get prompt text embeddings | ||
| 294 | text_inputs = self.tokenizer( | ||
| 295 | prompt, | ||
| 296 | padding="max_length", | ||
| 297 | max_length=self.tokenizer.model_max_length, | ||
| 298 | return_tensors="pt", | ||
| 299 | ) | ||
| 300 | text_input_ids = text_inputs.input_ids | ||
| 301 | |||
| 302 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: | ||
| 303 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) | ||
| 304 | logger.warning( | ||
| 305 | "The following part of your input was truncated because CLIP can only handle sequences up to" | ||
| 306 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" | ||
| 307 | ) | ||
| 308 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] | ||
| 309 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] | ||
| 310 | |||
| 311 | if clip_guidance_scale > 0: | ||
| 312 | if clip_prompt is not None: | ||
| 313 | clip_text_inputs = self.tokenizer( | ||
| 314 | clip_prompt, | ||
| 315 | padding="max_length", | ||
| 316 | max_length=self.tokenizer.model_max_length, | ||
| 317 | truncation=True, | ||
| 318 | return_tensors="pt", | ||
| 319 | ) | ||
| 320 | clip_text_input_ids = clip_text_inputs.input_ids | ||
| 321 | else: | ||
| 322 | clip_text_input_ids = text_input_ids | ||
| 323 | text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device)) | ||
| 324 | text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) | ||
| 325 | |||
| 326 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
| 327 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||
| 328 | # corresponds to doing no classifier free guidance. | ||
| 329 | do_classifier_free_guidance = guidance_scale > 1.0 | ||
| 330 | # get unconditional embeddings for classifier free guidance | ||
| 331 | if do_classifier_free_guidance: | ||
| 332 | max_length = text_input_ids.shape[-1] | ||
| 333 | uncond_input = self.tokenizer( | ||
| 334 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" | ||
| 335 | ) | ||
| 336 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] | ||
| 337 | |||
| 338 | # For classifier free guidance, we need to do two forward passes. | ||
| 339 | # Here we concatenate the unconditional and text embeddings into a single batch | ||
| 340 | # to avoid doing two forward passes | ||
| 341 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | ||
| 342 | |||
| 343 | # get the initial random noise unless the user supplied it | ||
| 344 | |||
| 345 | # Unlike in other pipelines, latents need to be generated in the target device | ||
| 346 | # for 1-to-1 results reproducibility with the CompVis implementation. | ||
| 347 | # However this currently doesn't work in `mps`. | ||
| 348 | latents_device = "cpu" if self.device.type == "mps" else self.device | ||
| 349 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) | ||
| 350 | if latents is None: | ||
| 351 | latents = torch.randn( | ||
| 352 | latents_shape, | ||
| 353 | generator=generator, | ||
| 354 | device=latents_device, | ||
| 355 | dtype=text_embeddings.dtype, | ||
| 356 | ) | ||
| 357 | else: | ||
| 358 | if latents.shape != latents_shape: | ||
| 359 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | ||
| 360 | latents = latents.to(self.device) | ||
| 361 | |||
| 362 | # set timesteps | ||
| 363 | self.scheduler.set_timesteps(num_inference_steps) | ||
| 364 | |||
| 365 | # Some schedulers like PNDM have timesteps as arrays | ||
| 366 | # It's more optimzed to move all timesteps to correct device beforehand | ||
| 367 | if torch.is_tensor(self.scheduler.timesteps): | ||
| 368 | timesteps_tensor = self.scheduler.timesteps.to(self.device) | ||
| 369 | else: | ||
| 370 | timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device) | ||
| 371 | |||
| 372 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
| 373 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 374 | latents = latents * self.scheduler.sigmas[0] | ||
| 375 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 376 | sigma = self.scheduler.timesteps[0] | ||
| 377 | latents = latents * sigma | ||
| 378 | |||
| 379 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | ||
| 380 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
| 381 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
| 382 | # and should be between [0, 1] | ||
| 383 | scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
| 384 | accepts_eta = "eta" in scheduler_step_args | ||
| 385 | extra_step_kwargs = {} | ||
| 386 | if accepts_eta: | ||
| 387 | extra_step_kwargs["eta"] = eta | ||
| 388 | accepts_generator = "generator" in scheduler_step_args | ||
| 389 | if generator is not None and accepts_generator: | ||
| 390 | extra_step_kwargs["generator"] = generator | ||
| 391 | |||
| 392 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | ||
| 393 | # expand the latents if we are doing classifier free guidance | ||
| 394 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
| 395 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 396 | sigma = self.scheduler.sigmas[i] | ||
| 397 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | ||
| 398 | latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) | ||
| 399 | |||
| 400 | noise_pred = None | ||
| 401 | if isinstance(self.scheduler, EulerAScheduler): | ||
| 402 | sigma = t.reshape(1) | ||
| 403 | sigma_in = torch.cat([sigma] * 2) | ||
| 404 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | ||
| 405 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | ||
| 406 | text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas) | ||
| 407 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
| 408 | else: | ||
| 409 | # predict the noise residual | ||
| 410 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | ||
| 411 | |||
| 412 | # perform guidance | ||
| 413 | if do_classifier_free_guidance: | ||
| 414 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||
| 415 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||
| 416 | |||
| 417 | # perform clip guidance | ||
| 418 | if clip_guidance_scale > 0: | ||
| 419 | text_embeddings_for_guidance = ( | ||
| 420 | text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings | ||
| 421 | ) | ||
| 422 | noise_pred, latents = self.cond_fn( | ||
| 423 | latents, | ||
| 424 | t, | ||
| 425 | i, | ||
| 426 | text_embeddings_for_guidance, | ||
| 427 | noise_pred, | ||
| 428 | text_embeddings_clip, | ||
| 429 | clip_guidance_scale, | ||
| 430 | num_cutouts, | ||
| 431 | use_cutouts, | ||
| 432 | ) | ||
| 433 | |||
| 434 | # compute the previous noisy sample x_t -> x_t-1 | ||
| 435 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
| 436 | latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample | ||
| 437 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 438 | if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error | ||
| 439 | t_prev = self.scheduler.timesteps[i+1] | ||
| 440 | latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample | ||
| 441 | else: | ||
| 442 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | ||
| 443 | |||
| 444 | # scale and decode the image latents with vae | ||
| 445 | latents = 1 / 0.18215 * latents | ||
| 446 | image = self.vae.decode(latents).sample | ||
| 447 | |||
| 448 | image = (image / 2 + 0.5).clamp(0, 1) | ||
| 449 | image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
| 450 | |||
| 451 | if output_type == "pil": | ||
| 452 | image = self.numpy_to_pil(image) | ||
| 453 | |||
| 454 | if not return_dict: | ||
| 455 | return (image, None) | ||
| 456 | |||
| 457 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) | ||
