author     Volpeon <git@volpeon.ink>   2022-11-14 18:41:38 +0100
committer  Volpeon <git@volpeon.ink>   2022-11-14 18:41:38 +0100
commit     8ff51a771905d0d14a3c690f54eb644515730348 (patch)
tree       f1096181e912291f85d82d95af88a9f4257c1b35 /pipelines
parent     Update (diff)
Refactoring
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  318
1 file changed, 214 insertions, 104 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index ba057ba..d6b1cb1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Callable
 
 import numpy as np
 import torch
@@ -136,11 +136,165 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    @property
+    def execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps):
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
+        if negative_prompt is None:
+            negative_prompt = ""
+
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * len(prompt)
+
+        if not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if not isinstance(negative_prompt, list):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        if len(negative_prompt) != len(prompt):
+            raise ValueError(
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        return prompt, negative_prompt
+
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+        text_input_ids = self.prompt_processor.get_input_ids(prompt)
+        text_input_ids *= num_images_per_prompt
+
+        if do_classifier_free_guidance:
+            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            unconditional_input_ids *= num_images_per_prompt
+            text_input_ids = unconditional_input_ids + text_input_ids
+
+        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+
+        return text_embeddings
+
+    def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
+        if latents_are_image:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:]
+        else:
+            timesteps = self.scheduler.timesteps
+
+        timesteps = timesteps.to(device)
+
+        return timesteps
+
+    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+
+        if latents is None:
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
+        return latents
+
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+        init_image = init_image.to(device=device, dtype=dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+        # get latents
+        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        latents = init_latents
+
+        return latents
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str], List[List[str]]],
         negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
@@ -148,9 +302,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -202,110 +358,60 @@ class VlpnStableDiffusion(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
 
-        if isinstance(prompt, str):
-            prompt = [prompt]
+        # 1. Check inputs. Raise error if not correct
+        prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
 
+        # 2. Define call parameters
         batch_size = len(prompt)
-
-        if negative_prompt is None:
-            negative_prompt = ""
-
-        if isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-
-        if len(negative_prompt) != len(prompt):
-            raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if strength < 0 or strength > 1:
-            raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # get prompt text embeddings
-        text_input_ids = self.prompt_processor.get_input_ids(prompt)
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
+        device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
-            text_input_ids = unconditional_input_ids + text_input_ids
-
-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
-
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = num_inference_steps + offset
-
-        # get the initial random noise unless the user supplied it
-
-        # Unlike in other pipelines, latents need to be generated in the target device
-        # for 1-to-1 results reproducibility with the CompVis implementation.
-        # However this currently doesn't work in `mps`.
-        latents_dtype = text_embeddings.dtype
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
 
-        if latents is None:
-            if self.device.type == "mps":
-                # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
-            else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
-        elif isinstance(latents, PIL.Image.Image):
-            latents = preprocess(latents, width, height)
-            latents = latents.to(device=self.device, dtype=latents_dtype)
-            latent_dist = self.vae.encode(latents).latent_dist
-            latents = latent_dist.sample(generator=generator)
-            latents = 0.18215 * latents
-
-            # expand init_latents for batch_size
-            latents = torch.cat([latents] * batch_size, dim=0)
-
-            # get the original timestep using init_timestep
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
+        print(f">>> {device}")
 
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        # 3. Encode input prompt
+        text_embeddings = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            num_images_per_prompt,
+            do_classifier_free_guidance
+        )
 
-            # add noise to latents using the timesteps
-            noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
-            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device)
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.in_channels
+        if latents_are_image:
+            latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+            latents = self.prepare_latents_from_image(
+                latents_or_image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                text_embeddings.dtype,
+                device,
+                generator
+            )
         else:
-            if latents.shape != latents_shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            if latents.device != self.device:
-                raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
+            latents = self.prepare_latents(
+                batch_size,
+                num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                text_embeddings.dtype,
+                device,
+                generator,
+                latents_or_image,
+            )
 
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
-        accepts_eta = "eta" in scheduler_step_args
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-        accepts_generator = "generator" in scheduler_step_args
-        if generator is not None and accepts_generator:
-            extra_step_kwargs["generator"] = generator
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+        # 7. Denoising loop
+        for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -321,17 +427,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
-            # scale and decode the image latents with vae
-            latents = 1 / 0.18215 * latents
-            image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
+            # call the callback, if provided
+            if callback is not None and i % callback_steps == 0:
+                callback(i, t, latents)
 
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        # 8. Post-processing
+        image = self.decode_latents(latents)
+
+        # 9. Run safety checker
+        has_nsfw_concept = None
 
+        # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
         if not return_dict:
-            return (image, None)
+            return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
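
A minimal usage sketch of the refactored __call__ (illustration only, not part of the commit). It assumes an already constructed VlpnStableDiffusion instance named `pipeline`; the prompt and output path are placeholders, and `num_inference_steps` comes from the unchanged part of the signature.

    import torch

    def on_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # invoked every `callback_steps` denoising steps with the current latents
        print(f"step {step}: latents shape {tuple(latents.shape)}")

    result = pipeline(
        prompt="a watercolor painting of a fox",
        negative_prompt="blurry, low quality",
        num_images_per_prompt=2,
        num_inference_steps=50,
        guidance_scale=7.5,
        callback=on_step,
        callback_steps=5,
    )
    result.images[0].save("fox.png")

Passing a PIL.Image through `latents_or_image` instead selects the img2img branch (prepare_latents_from_image), with `strength` controlling how much of the denoising schedule is re-run via get_timesteps.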