diff options
author | Volpeon <git@volpeon.ink> | 2022-11-14 18:41:38 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-14 18:41:38 +0100 |
commit | 8ff51a771905d0d14a3c690f54eb644515730348 (patch) | |
tree | f1096181e912291f85d82d95af88a9f4257c1b35 /pipelines/stable_diffusion | |
parent | Update (diff) | |
download | textual-inversion-diff-8ff51a771905d0d14a3c690f54eb644515730348.tar.gz textual-inversion-diff-8ff51a771905d0d14a3c690f54eb644515730348.tar.bz2 textual-inversion-diff-8ff51a771905d0d14a3c690f54eb644515730348.zip |
Refactoring
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 318 |
1 files changed, 214 insertions, 104 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index ba057ba..d6b1cb1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -1,6 +1,6 @@ | |||
1 | import inspect | 1 | import inspect |
2 | import warnings | 2 | import warnings |
3 | from typing import List, Optional, Union | 3 | from typing import List, Optional, Union, Callable |
4 | 4 | ||
5 | import numpy as np | 5 | import numpy as np |
6 | import torch | 6 | import torch |
@@ -136,11 +136,165 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
136 | if cpu_offloaded_model is not None: | 136 | if cpu_offloaded_model is not None: |
137 | cpu_offload(cpu_offloaded_model, device) | 137 | cpu_offload(cpu_offloaded_model, device) |
138 | 138 | ||
139 | @property | ||
140 | def execution_device(self): | ||
141 | r""" | ||
142 | Returns the device on which the pipeline's models will be executed. After calling | ||
143 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module | ||
144 | hooks. | ||
145 | """ | ||
146 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): | ||
147 | return self.device | ||
148 | for module in self.unet.modules(): | ||
149 | if ( | ||
150 | hasattr(module, "_hf_hook") | ||
151 | and hasattr(module._hf_hook, "execution_device") | ||
152 | and module._hf_hook.execution_device is not None | ||
153 | ): | ||
154 | return torch.device(module._hf_hook.execution_device) | ||
155 | return self.device | ||
156 | |||
157 | def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps): | ||
158 | if isinstance(prompt, str): | ||
159 | prompt = [prompt] | ||
160 | |||
161 | if negative_prompt is None: | ||
162 | negative_prompt = "" | ||
163 | |||
164 | if isinstance(negative_prompt, str): | ||
165 | negative_prompt = [negative_prompt] * len(prompt) | ||
166 | |||
167 | if not isinstance(prompt, list): | ||
168 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
169 | |||
170 | if not isinstance(negative_prompt, list): | ||
171 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | ||
172 | |||
173 | if len(negative_prompt) != len(prompt): | ||
174 | raise ValueError( | ||
175 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
176 | |||
177 | if strength < 0 or strength > 1: | ||
178 | raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") | ||
179 | |||
180 | if height % 8 != 0 or width % 8 != 0: | ||
181 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
182 | |||
183 | if (callback_steps is None) or ( | ||
184 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) | ||
185 | ): | ||
186 | raise ValueError( | ||
187 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | ||
188 | f" {type(callback_steps)}." | ||
189 | ) | ||
190 | |||
191 | return prompt, negative_prompt | ||
192 | |||
193 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance): | ||
194 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | ||
195 | text_input_ids *= num_images_per_prompt | ||
196 | |||
197 | if do_classifier_free_guidance: | ||
198 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) | ||
199 | unconditional_input_ids *= num_images_per_prompt | ||
200 | text_input_ids = unconditional_input_ids + text_input_ids | ||
201 | |||
202 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) | ||
203 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) | ||
204 | |||
205 | return text_embeddings | ||
206 | |||
207 | def get_timesteps(self, latents_are_image, num_inference_steps, strength, device): | ||
208 | if latents_are_image: | ||
209 | # get the original timestep using init_timestep | ||
210 | offset = self.scheduler.config.get("steps_offset", 0) | ||
211 | init_timestep = int(num_inference_steps * strength) + offset | ||
212 | init_timestep = min(init_timestep, num_inference_steps) | ||
213 | |||
214 | t_start = max(num_inference_steps - init_timestep + offset, 0) | ||
215 | timesteps = self.scheduler.timesteps[t_start:] | ||
216 | else: | ||
217 | timesteps = self.scheduler.timesteps | ||
218 | |||
219 | timesteps = timesteps.to(device) | ||
220 | |||
221 | return timesteps | ||
222 | |||
223 | def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): | ||
224 | shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) | ||
225 | |||
226 | if latents is None: | ||
227 | if device.type == "mps": | ||
228 | # randn does not work reproducibly on mps | ||
229 | latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) | ||
230 | else: | ||
231 | latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) | ||
232 | else: | ||
233 | if latents.shape != shape: | ||
234 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") | ||
235 | latents = latents.to(device) | ||
236 | |||
237 | # scale the initial noise by the standard deviation required by the scheduler | ||
238 | latents = latents * self.scheduler.init_noise_sigma | ||
239 | |||
240 | return latents | ||
241 | |||
242 | def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): | ||
243 | init_image = init_image.to(device=device, dtype=dtype) | ||
244 | init_latent_dist = self.vae.encode(init_image).latent_dist | ||
245 | init_latents = init_latent_dist.sample(generator=generator) | ||
246 | init_latents = 0.18215 * init_latents | ||
247 | |||
248 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: | ||
249 | additional_image_per_prompt = batch_size // init_latents.shape[0] | ||
250 | init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) | ||
251 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: | ||
252 | raise ValueError( | ||
253 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | ||
254 | ) | ||
255 | else: | ||
256 | init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) | ||
257 | |||
258 | # add noise to latents using the timesteps | ||
259 | noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) | ||
260 | |||
261 | # get latents | ||
262 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) | ||
263 | latents = init_latents | ||
264 | |||
265 | return latents | ||
266 | |||
267 | def prepare_extra_step_kwargs(self, generator, eta): | ||
268 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | ||
269 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
270 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
271 | # and should be between [0, 1] | ||
272 | |||
273 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
274 | extra_step_kwargs = {} | ||
275 | if accepts_eta: | ||
276 | extra_step_kwargs["eta"] = eta | ||
277 | |||
278 | # check if the scheduler accepts generator | ||
279 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
280 | if accepts_generator: | ||
281 | extra_step_kwargs["generator"] = generator | ||
282 | return extra_step_kwargs | ||
283 | |||
284 | def decode_latents(self, latents): | ||
285 | latents = 1 / 0.18215 * latents | ||
286 | image = self.vae.decode(latents).sample | ||
287 | image = (image / 2 + 0.5).clamp(0, 1) | ||
288 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | ||
289 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | ||
290 | return image | ||
291 | |||
139 | @torch.no_grad() | 292 | @torch.no_grad() |
140 | def __call__( | 293 | def __call__( |
141 | self, | 294 | self, |
142 | prompt: Union[str, List[str], List[List[str]]], | 295 | prompt: Union[str, List[str], List[List[str]]], |
143 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, | 296 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, |
297 | num_images_per_prompt: Optional[int] = 1, | ||
144 | strength: float = 0.8, | 298 | strength: float = 0.8, |
145 | height: Optional[int] = 512, | 299 | height: Optional[int] = 512, |
146 | width: Optional[int] = 512, | 300 | width: Optional[int] = 512, |
@@ -148,9 +302,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
148 | guidance_scale: Optional[float] = 7.5, | 302 | guidance_scale: Optional[float] = 7.5, |
149 | eta: Optional[float] = 0.0, | 303 | eta: Optional[float] = 0.0, |
150 | generator: Optional[torch.Generator] = None, | 304 | generator: Optional[torch.Generator] = None, |
151 | latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 305 | latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
152 | output_type: Optional[str] = "pil", | 306 | output_type: Optional[str] = "pil", |
153 | return_dict: bool = True, | 307 | return_dict: bool = True, |
308 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | ||
309 | callback_steps: Optional[int] = 1, | ||
154 | ): | 310 | ): |
155 | r""" | 311 | r""" |
156 | Function invoked when calling the pipeline for generation. | 312 | Function invoked when calling the pipeline for generation. |
@@ -202,110 +358,60 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
202 | (nsfw) content, according to the `safety_checker`. | 358 | (nsfw) content, according to the `safety_checker`. |
203 | """ | 359 | """ |
204 | 360 | ||
205 | if isinstance(prompt, str): | 361 | # 1. Check inputs. Raise error if not correct |
206 | prompt = [prompt] | 362 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) |
207 | 363 | ||
364 | # 2. Define call parameters | ||
208 | batch_size = len(prompt) | 365 | batch_size = len(prompt) |
209 | 366 | device = self.execution_device | |
210 | if negative_prompt is None: | ||
211 | negative_prompt = "" | ||
212 | |||
213 | if isinstance(negative_prompt, str): | ||
214 | negative_prompt = [negative_prompt] * batch_size | ||
215 | |||
216 | if len(negative_prompt) != len(prompt): | ||
217 | raise ValueError( | ||
218 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
219 | |||
220 | if height % 8 != 0 or width % 8 != 0: | ||
221 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
222 | |||
223 | if strength < 0 or strength > 1: | ||
224 | raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") | ||
225 | |||
226 | # set timesteps | ||
227 | self.scheduler.set_timesteps(num_inference_steps) | ||
228 | |||
229 | # get prompt text embeddings | ||
230 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | ||
231 | |||
232 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
233 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||
234 | # corresponds to doing no classifier free guidance. | ||
235 | do_classifier_free_guidance = guidance_scale > 1.0 | 367 | do_classifier_free_guidance = guidance_scale > 1.0 |
236 | # get unconditional embeddings for classifier free guidance | 368 | latents_are_image = isinstance(latents_or_image, PIL.Image.Image) |
237 | if do_classifier_free_guidance: | ||
238 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) | ||
239 | text_input_ids = unconditional_input_ids + text_input_ids | ||
240 | |||
241 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) | ||
242 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) | ||
243 | |||
244 | offset = self.scheduler.config.get("steps_offset", 0) | ||
245 | init_timestep = num_inference_steps + offset | ||
246 | |||
247 | # get the initial random noise unless the user supplied it | ||
248 | |||
249 | # Unlike in other pipelines, latents need to be generated in the target device | ||
250 | # for 1-to-1 results reproducibility with the CompVis implementation. | ||
251 | # However this currently doesn't work in `mps`. | ||
252 | latents_dtype = text_embeddings.dtype | ||
253 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) | ||
254 | 369 | ||
255 | if latents is None: | 370 | print(f">>> {device}") |
256 | if self.device.type == "mps": | ||
257 | # randn does not exist on mps | ||
258 | latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( | ||
259 | self.device | ||
260 | ) | ||
261 | else: | ||
262 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | ||
263 | elif isinstance(latents, PIL.Image.Image): | ||
264 | latents = preprocess(latents, width, height) | ||
265 | latents = latents.to(device=self.device, dtype=latents_dtype) | ||
266 | latent_dist = self.vae.encode(latents).latent_dist | ||
267 | latents = latent_dist.sample(generator=generator) | ||
268 | latents = 0.18215 * latents | ||
269 | |||
270 | # expand init_latents for batch_size | ||
271 | latents = torch.cat([latents] * batch_size, dim=0) | ||
272 | |||
273 | # get the original timestep using init_timestep | ||
274 | init_timestep = int(num_inference_steps * strength) + offset | ||
275 | init_timestep = min(init_timestep, num_inference_steps) | ||
276 | 371 | ||
277 | timesteps = self.scheduler.timesteps[-init_timestep] | 372 | # 3. Encode input prompt |
278 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 373 | text_embeddings = self.encode_prompt( |
374 | prompt, | ||
375 | negative_prompt, | ||
376 | num_images_per_prompt, | ||
377 | do_classifier_free_guidance | ||
378 | ) | ||
279 | 379 | ||
280 | # add noise to latents using the timesteps | 380 | # 4. Prepare timesteps |
281 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) | 381 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
282 | latents = self.scheduler.add_noise(latents, noise, timesteps) | 382 | timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device) |
383 | |||
384 | # 5. Prepare latent variables | ||
385 | num_channels_latents = self.unet.in_channels | ||
386 | if latents_are_image: | ||
387 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | ||
388 | latents = self.prepare_latents_from_image( | ||
389 | latents_or_image, | ||
390 | latent_timestep, | ||
391 | batch_size, | ||
392 | num_images_per_prompt, | ||
393 | text_embeddings.dtype, | ||
394 | device, | ||
395 | generator | ||
396 | ) | ||
283 | else: | 397 | else: |
284 | if latents.shape != latents_shape: | 398 | latents = self.prepare_latents( |
285 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | 399 | batch_size, |
286 | if latents.device != self.device: | 400 | num_images_per_prompt, |
287 | raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") | 401 | num_channels_latents, |
288 | 402 | height, | |
289 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 403 | width, |
290 | 404 | text_embeddings.dtype, | |
291 | # Some schedulers like PNDM have timesteps as arrays | 405 | device, |
292 | # It's more optimzed to move all timesteps to correct device beforehand | 406 | generator, |
293 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) | 407 | latents_or_image, |
408 | ) | ||
294 | 409 | ||
295 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 410 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
296 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 411 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
297 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
298 | # and should be between [0, 1] | ||
299 | scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
300 | accepts_eta = "eta" in scheduler_step_args | ||
301 | extra_step_kwargs = {} | ||
302 | if accepts_eta: | ||
303 | extra_step_kwargs["eta"] = eta | ||
304 | accepts_generator = "generator" in scheduler_step_args | ||
305 | if generator is not None and accepts_generator: | ||
306 | extra_step_kwargs["generator"] = generator | ||
307 | 412 | ||
308 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 413 | # 7. Denoising loop |
414 | for i, t in enumerate(self.progress_bar(timesteps)): | ||
309 | # expand the latents if we are doing classifier free guidance | 415 | # expand the latents if we are doing classifier free guidance |
310 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 416 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
311 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 417 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
@@ -321,17 +427,21 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
321 | # compute the previous noisy sample x_t -> x_t-1 | 427 | # compute the previous noisy sample x_t -> x_t-1 |
322 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 428 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
323 | 429 | ||
324 | # scale and decode the image latents with vae | 430 | # call the callback, if provided |
325 | latents = 1 / 0.18215 * latents | 431 | if callback is not None and i % callback_steps == 0: |
326 | image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | 432 | callback(i, t, latents) |
327 | 433 | ||
328 | image = (image / 2 + 0.5).clamp(0, 1) | 434 | # 8. Post-processing |
329 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 435 | image = self.decode_latents(latents) |
436 | |||
437 | # 9. Run safety checker | ||
438 | has_nsfw_concept = None | ||
330 | 439 | ||
440 | # 10. Convert to PIL | ||
331 | if output_type == "pil": | 441 | if output_type == "pil": |
332 | image = self.numpy_to_pil(image) | 442 | image = self.numpy_to_pil(image) |
333 | 443 | ||
334 | if not return_dict: | 444 | if not return_dict: |
335 | return (image, None) | 445 | return (image, has_nsfw_concept) |
336 | 446 | ||
337 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) | 447 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |