author | Volpeon <git@volpeon.ink> | 2022-10-02 12:56:58 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-02 12:56:58 +0200 |
commit | 49de8142f523aef3f6adfd0c33a9a160aa7400c0 (patch) | |
tree | 3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61 /pipelines/stable_diffusion/vlpn_stable_diffusion.py | |
parent | Fix seed, better progress bar, fix euler_a for batch size > 1 (diff) | |
download | textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.tar.gz textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.tar.bz2 textual-inversion-diff-49de8142f523aef3f6adfd0c33a9a160aa7400c0.zip |
WIP: img2img
Diffstat (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 342 |
1 file changed, 342 insertions, 0 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
new file mode 100644
index 0000000..4c793a8
--- /dev/null
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -0,0 +1,342 @@
import inspect
import warnings
from typing import List, Optional, Union

import numpy as np
import torch
import PIL

from diffusers.configuration_utils import FrozenDict
from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging
from transformers import CLIPTextModel, CLIPTokenizer
from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

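# preprocess() resizes the init image to the target resolution and maps pixel
# values from [0, 255] to [-1, 1] in NCHW layout, the range the VAE encoder
# expects: a white pixel (255) maps to 1.0, a black pixel (0) to -1.0.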
def preprocess(image, w, h):
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

class VlpnStableDiffusion(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
        **kwargs,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set slice_size = `None` to disable attention slicing
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e. if
                `guidance_scale` is less than or equal to `1`).
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference image passed via `latents`. Must be
                between 0 and 1. The image will be used as a starting point, with more noise added to it the larger
                the `strength`. The number of denoising steps depends on the amount of noise initially added. When
                `strength` is 1, the added noise will be maximum and the denoising process will run for the full
                number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores
                the reference image.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation, or a `PIL.Image.Image` to use as the starting point for img2img. Can be used to tweak the
                same generation with different prompts. If not provided, a latents tensor will be generated by
                sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

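        Examples:
            A minimal img2img sketch, assuming `pipe` is an already constructed `VlpnStableDiffusion`;
            the file names and prompt are illustrative only:

            >>> init_image = PIL.Image.open("sketch.png").convert("RGB")
            >>> result = pipe(
            ...     prompt="a detailed watercolor landscape",
            ...     latents=init_image,
            ...     strength=0.6,
            ... )
            >>> result.images[0].save("out.png")
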
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is always `None`, as this pipeline does not include a safety checker.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        elif isinstance(negative_prompt, list):
            if len(negative_prompt) != batch_size:
                raise ValueError(
                    f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
        else:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if strength < 0 or strength > 1:
            raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}")

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        offset = self.scheduler.config.get("steps_offset", 0)

        if latents is not None and isinstance(latents, PIL.Image.Image):
            latents = preprocess(latents, width, height)
            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
            latents = latent_dist.sample(generator=generator)
            latents = 0.18215 * latents
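            # 0.18215 is the latent scaling factor used by Stable Diffusion's VAE;
            # the inverse (1 / 0.18215) is applied again before decoding below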
            latents = torch.cat([latents] * batch_size)

            # get the original timestep using init_timestep
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)
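            # e.g. with num_inference_steps=50, strength=0.8 and offset=1, init_timestep
            # is 41: denoising starts 41 steps from the end of the schedule, so weaker
            # strengths keep more of the init image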

            if isinstance(self.scheduler, LMSDiscreteScheduler):
                timesteps = torch.tensor(
                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
                )
            elif isinstance(self.scheduler, EulerAScheduler):
                timesteps = self.scheduler.timesteps[-init_timestep]
                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
            else:
                timesteps = self.scheduler.timesteps[-init_timestep]
                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

            # add noise to latents using the timesteps
            noise = torch.randn(latents.shape, generator=generator, device=self.device)
            latents = self.scheduler.add_noise(latents, noise, timesteps)
        else:
            init_timestep = num_inference_steps + offset

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier-free guidance
        if do_classifier_free_guidance:
            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier-free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
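            # text_embeddings now has shape [2 * batch_size, seq_len, dim]
            # (e.g. [2, 77, 768] for batch_size=1 with SD's CLIP ViT-L/14 encoder);
            # the doubled batch yields both predictions in one UNet forward pass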

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_device = "cpu" if self.device.type == "mps" else self.device
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
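        # the VAE downsamples by a factor of 8, so e.g. a 512x512 image corresponds
        # to a 64x64 latent grid (with in_channels=4 for Stable Diffusion's UNet)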
        if latents is None:
            latents = torch.randn(
                latents_shape,
                generator=generator,
                device=latents_device,
                dtype=text_embeddings.dtype,
            )
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        t_start = max(num_inference_steps - init_timestep + offset, 0)
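        # continuing the example above: with 50 steps, strength 0.8 and offset 1,
        # t_start = max(50 - 41 + 1, 0) = 10, i.e. the first 10 timesteps of the
        # schedule are skipped; for pure text-to-image, t_start is 0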

        # Some schedulers like PNDM have timesteps as arrays.
        # It's more optimized to move all timesteps to the correct device beforehand.
        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]
        elif isinstance(self.scheduler, EulerAScheduler):
            sigma = self.scheduler.timesteps[0]
            latents = latents * sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be in [0, 1]
        scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
        accepts_eta = "eta" in scheduler_step_args
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        accepts_generator = "generator" in scheduler_step_args
        if generator is not None and accepts_generator:
            extra_step_kwargs["generator"] = generator

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            t_index = t_start + i

            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[t_index]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            noise_pred = None
            if isinstance(self.scheduler, EulerAScheduler):
                # for euler_a, the schedule entries are sigmas; CFGDenoiserForward wraps the UNet
                sigma = t.reshape(1)
                sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
                noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
                                                text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
            else:
                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
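                # equivalently, noise_pred = (1 - guidance_scale) * eps_uncond + guidance_scale * eps_text:
                # the unconditional prediction is pushed towards the text-conditioned one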

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
            elif isinstance(self.scheduler, EulerAScheduler):
                if t_index < self.scheduler.timesteps.shape[0] - 1:  # avoid out-of-bounds error
                    t_prev = self.scheduler.timesteps[t_index + 1]
                    latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with the vae
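        # undo the 0.18215 scaling factor applied when the latents were encoded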
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)