diff options
author | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
commit | 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch) | |
tree | d275e13506ca737efef18dc6dffa05f4e0d6759f /pipelines | |
parent | Improved aspect ratio bucketing (diff) | |
download | textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2 textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip |
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 53b5eea..cb300d1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
79 | unet=unet, | 79 | unet=unet, |
80 | scheduler=scheduler, | 80 | scheduler=scheduler, |
81 | ) | 81 | ) |
82 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | ||
82 | 83 | ||
83 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 84 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): |
84 | r""" | 85 | r""" |
@@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
160 | return torch.device(module._hf_hook.execution_device) | 161 | return torch.device(module._hf_hook.execution_device) |
161 | return self.device | 162 | return self.device |
162 | 163 | ||
163 | def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps): | 164 | def check_inputs( |
164 | if isinstance(prompt, str): | 165 | self, |
166 | prompt: Union[str, List[str], List[int], List[List[int]]], | ||
167 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]], | ||
168 | width: Optional[int], | ||
169 | height: Optional[int], | ||
170 | strength: float, | ||
171 | callback_steps: Optional[int] | ||
172 | ): | ||
173 | if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): | ||
165 | prompt = [prompt] | 174 | prompt = [prompt] |
166 | 175 | ||
167 | if negative_prompt is None: | 176 | if negative_prompt is None: |
168 | negative_prompt = "" | 177 | negative_prompt = "" |
169 | 178 | ||
170 | if isinstance(negative_prompt, str): | 179 | if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): |
171 | negative_prompt = [negative_prompt] * len(prompt) | 180 | negative_prompt = [negative_prompt] * len(prompt) |
172 | 181 | ||
173 | if not isinstance(prompt, list): | 182 | if not isinstance(prompt, list): |
@@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
196 | 205 | ||
197 | return prompt, negative_prompt | 206 | return prompt, negative_prompt |
198 | 207 | ||
199 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): | 208 | def encode_prompt( |
200 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | 209 | self, |
210 | prompt: Union[List[str], List[List[int]]], | ||
211 | negative_prompt: Union[List[str], List[List[int]]], | ||
212 | num_images_per_prompt: int, | ||
213 | do_classifier_free_guidance: bool, | ||
214 | device | ||
215 | ): | ||
216 | text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt | ||
201 | text_input_ids *= num_images_per_prompt | 217 | text_input_ids *= num_images_per_prompt |
202 | 218 | ||
203 | if do_classifier_free_guidance: | 219 | if do_classifier_free_guidance: |
204 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) | 220 | unconditional_input_ids = self.prompt_processor.get_input_ids( |
221 | negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt | ||
205 | unconditional_input_ids *= num_images_per_prompt | 222 | unconditional_input_ids *= num_images_per_prompt |
206 | text_input_ids = unconditional_input_ids + text_input_ids | 223 | text_input_ids = unconditional_input_ids + text_input_ids |
207 | 224 | ||
@@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
314 | @torch.no_grad() | 331 | @torch.no_grad() |
315 | def __call__( | 332 | def __call__( |
316 | self, | 333 | self, |
317 | prompt: Union[str, List[str], List[List[str]]], | 334 | prompt: Union[str, List[str], List[int], List[List[int]]], |
318 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, | 335 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, |
319 | num_images_per_prompt: Optional[int] = 1, | 336 | num_images_per_prompt: Optional[int] = 1, |
320 | strength: float = 0.8, | 337 | strength: float = 0.8, |
321 | height: Optional[int] = 768, | 338 | height: Optional[int] = None, |
322 | width: Optional[int] = 768, | 339 | width: Optional[int] = None, |
323 | num_inference_steps: Optional[int] = 50, | 340 | num_inference_steps: Optional[int] = 50, |
324 | guidance_scale: Optional[float] = 7.5, | 341 | guidance_scale: Optional[float] = 7.5, |
325 | eta: Optional[float] = 0.0, | 342 | eta: Optional[float] = 0.0, |
@@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
379 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | 396 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" |
380 | (nsfw) content, according to the `safety_checker`. | 397 | (nsfw) content, according to the `safety_checker`. |
381 | """ | 398 | """ |
399 | # 0. Default height and width to unet | ||
400 | height = height or self.unet.config.sample_size * self.vae_scale_factor | ||
401 | width = width or self.unet.config.sample_size * self.vae_scale_factor | ||
382 | 402 | ||
383 | # 1. Check inputs. Raise error if not correct | 403 | # 1. Check inputs. Raise error if not correct |
384 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) | 404 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) |