path: root/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
author     Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-02 12:56:58 +0200
commit     49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch)
tree       3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /pipelines/stable_diffusion/clip_guided_stable_diffusion.py
parent     Fix seed, better progress bar, fix euler_a for batch size > 1 (diff)
download   textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.tar.gz
           textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.tar.bz2
           textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.zip
WIP: img2img
Diffstat (limited to 'pipelines/stable_diffusion/clip_guided_stable_diffusion.py')
-rw-r--r--  pipelines/stable_diffusion/clip_guided_stable_diffusion.py  294
1 file changed, 0 insertions, 294 deletions
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
deleted file mode 100644
index eff74b5..0000000
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ /dev/null
@@ -1,294 +0,0 @@
import inspect
import warnings
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import FrozenDict
from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CLIPGuidedStableDiffusion(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
        **kwargs,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull Request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)
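            # the scheduler config is a FrozenDict and cannot be mutated in place, so the dict
            # is rebuilt with the corrected value and swapped in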

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices and computes
        attention in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method
        goes back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when guidance is disabled
                (i.e. when `guidance_scale` is 1 or less).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content; this pipeline does not run a safety checker, so it is always `None`.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        elif isinstance(negative_prompt, list):
            if len(negative_prompt) != batch_size:
                raise ValueError(
                    f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
        else:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            print(f"Too many tokens: {removed_text}")
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
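        # when guidance_scale <= 1 the unconditional branch below is skipped entirely, so any
        # negative prompt has no effect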
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
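            # the unconditional embeddings come first; the chunk(2) in the denoising loop below
            # relies on this [uncond, cond] ordering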

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_device = "cpu" if self.device.type == "mps" else self.device
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
        if latents is None:
            latents = torch.randn(
                latents_shape,
                generator=generator,
                device=latents_device,
                dtype=text_embeddings.dtype,
            )
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more efficient to move all timesteps to the correct device beforehand
        if torch.is_tensor(self.scheduler.timesteps):
            timesteps_tensor = self.scheduler.timesteps.to(self.device)
        else:
            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
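        # (these k-diffusion style schedulers expect the initial noise to have standard deviation
        # sigma_max rather than 1, hence the scaling)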
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]
        elif isinstance(self.scheduler, EulerAScheduler):
            sigma = self.scheduler.timesteps[0]
            latents = latents * sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
        accepts_eta = "eta" in scheduler_step_args
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        accepts_generator = "generator" in scheduler_step_args
        if generator is not None and accepts_generator:
            extra_step_kwargs["generator"] = generator

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
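                # (this matches the 1 / sqrt(sigma^2 + 1) input preconditioning, often called c_in,
                # used by k-diffusion samplers for epsilon-prediction models)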
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            noise_pred = None
            if isinstance(self.scheduler, EulerAScheduler):
                sigma = t.reshape(1)
                sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
                # noise_pred = model(latent_model_input, sigma_in, uncond_embeddings, text_embeddings, guidance_scale)
                noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
                                                text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
                # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
            else:
                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
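            # the batch holds [uncond, cond] predictions (matching the embedding concat above);
            # they are mixed as eps_uncond + guidance_scale * (eps_text - eps_uncond)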
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
            elif isinstance(self.scheduler, EulerAScheduler):
                if i < self.scheduler.timesteps.shape[0] - 1:  # avoid out of bound error
                    t_prev = self.scheduler.timesteps[i + 1]
                    latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
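        # 0.18215 is the factor the latents were scaled by when images were encoded for
        # Stable Diffusion, so it is divided back out before decoding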
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
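
Usage note: a minimal sketch of how this pipeline could be driven, assuming the usual Stable
Diffusion components (vae, text_encoder, tokenizer, unet) are already loaded; the constructor
and __call__ signatures come from the file above, everything else (scheduler settings, prompts,
seed, output path) is illustrative.

# hypothetical usage sketch; vae, text_encoder, tokenizer and unet are assumed to be
# pre-loaded Stable Diffusion components
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = CLIPGuidedStableDiffusion(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
result = pipe(
    prompt="a photograph of an astronaut riding a horse",
    negative_prompt="blurry, low quality",
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=generator,
)
result.images[0].save("out.png")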