diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 53b5eea..cb300d1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 79 | unet=unet, | 79 | unet=unet, |
| 80 | scheduler=scheduler, | 80 | scheduler=scheduler, |
| 81 | ) | 81 | ) |
| 82 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | ||
| 82 | 83 | ||
| 83 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 84 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): |
| 84 | r""" | 85 | r""" |
| @@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 160 | return torch.device(module._hf_hook.execution_device) | 161 | return torch.device(module._hf_hook.execution_device) |
| 161 | return self.device | 162 | return self.device |
| 162 | 163 | ||
| 163 | def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps): | 164 | def check_inputs( |
| 164 | if isinstance(prompt, str): | 165 | self, |
| 166 | prompt: Union[str, List[str], List[int], List[List[int]]], | ||
| 167 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]], | ||
| 168 | width: Optional[int], | ||
| 169 | height: Optional[int], | ||
| 170 | strength: float, | ||
| 171 | callback_steps: Optional[int] | ||
| 172 | ): | ||
| 173 | if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): | ||
| 165 | prompt = [prompt] | 174 | prompt = [prompt] |
| 166 | 175 | ||
| 167 | if negative_prompt is None: | 176 | if negative_prompt is None: |
| 168 | negative_prompt = "" | 177 | negative_prompt = "" |
| 169 | 178 | ||
| 170 | if isinstance(negative_prompt, str): | 179 | if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): |
| 171 | negative_prompt = [negative_prompt] * len(prompt) | 180 | negative_prompt = [negative_prompt] * len(prompt) |
| 172 | 181 | ||
| 173 | if not isinstance(prompt, list): | 182 | if not isinstance(prompt, list): |
| @@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 196 | 205 | ||
| 197 | return prompt, negative_prompt | 206 | return prompt, negative_prompt |
| 198 | 207 | ||
| 199 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): | 208 | def encode_prompt( |
| 200 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | 209 | self, |
| 210 | prompt: Union[List[str], List[List[int]]], | ||
| 211 | negative_prompt: Union[List[str], List[List[int]]], | ||
| 212 | num_images_per_prompt: int, | ||
| 213 | do_classifier_free_guidance: bool, | ||
| 214 | device | ||
| 215 | ): | ||
| 216 | text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt | ||
| 201 | text_input_ids *= num_images_per_prompt | 217 | text_input_ids *= num_images_per_prompt |
| 202 | 218 | ||
| 203 | if do_classifier_free_guidance: | 219 | if do_classifier_free_guidance: |
| 204 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) | 220 | unconditional_input_ids = self.prompt_processor.get_input_ids( |
| 221 | negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt | ||
| 205 | unconditional_input_ids *= num_images_per_prompt | 222 | unconditional_input_ids *= num_images_per_prompt |
| 206 | text_input_ids = unconditional_input_ids + text_input_ids | 223 | text_input_ids = unconditional_input_ids + text_input_ids |
| 207 | 224 | ||
| @@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 314 | @torch.no_grad() | 331 | @torch.no_grad() |
| 315 | def __call__( | 332 | def __call__( |
| 316 | self, | 333 | self, |
| 317 | prompt: Union[str, List[str], List[List[str]]], | 334 | prompt: Union[str, List[str], List[int], List[List[int]]], |
| 318 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, | 335 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, |
| 319 | num_images_per_prompt: Optional[int] = 1, | 336 | num_images_per_prompt: Optional[int] = 1, |
| 320 | strength: float = 0.8, | 337 | strength: float = 0.8, |
| 321 | height: Optional[int] = 768, | 338 | height: Optional[int] = None, |
| 322 | width: Optional[int] = 768, | 339 | width: Optional[int] = None, |
| 323 | num_inference_steps: Optional[int] = 50, | 340 | num_inference_steps: Optional[int] = 50, |
| 324 | guidance_scale: Optional[float] = 7.5, | 341 | guidance_scale: Optional[float] = 7.5, |
| 325 | eta: Optional[float] = 0.0, | 342 | eta: Optional[float] = 0.0, |
| @@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 379 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | 396 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" |
| 380 | (nsfw) content, according to the `safety_checker`. | 397 | (nsfw) content, according to the `safety_checker`. |
| 381 | """ | 398 | """ |
| 399 | # 0. Default height and width to unet | ||
| 400 | height = height or self.unet.config.sample_size * self.vae_scale_factor | ||
| 401 | width = width or self.unet.config.sample_size * self.vae_scale_factor | ||
| 382 | 402 | ||
| 383 | # 1. Check inputs. Raise error if not correct | 403 | # 1. Check inputs. Raise error if not correct |
| 384 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) | 404 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) |
