summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py40
1 files changed, 30 insertions, 10 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 53b5eea..cb300d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
79 unet=unet, 79 unet=unet,
80 scheduler=scheduler, 80 scheduler=scheduler,
81 ) 81 )
82 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
82 83
83 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 84 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
84 r""" 85 r"""
@@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
160 return torch.device(module._hf_hook.execution_device) 161 return torch.device(module._hf_hook.execution_device)
161 return self.device 162 return self.device
162 163
163 def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps): 164 def check_inputs(
164 if isinstance(prompt, str): 165 self,
166 prompt: Union[str, List[str], List[int], List[List[int]]],
167 negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
168 width: Optional[int],
169 height: Optional[int],
170 strength: float,
171 callback_steps: Optional[int]
172 ):
173 if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
165 prompt = [prompt] 174 prompt = [prompt]
166 175
167 if negative_prompt is None: 176 if negative_prompt is None:
168 negative_prompt = "" 177 negative_prompt = ""
169 178
170 if isinstance(negative_prompt, str): 179 if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
171 negative_prompt = [negative_prompt] * len(prompt) 180 negative_prompt = [negative_prompt] * len(prompt)
172 181
173 if not isinstance(prompt, list): 182 if not isinstance(prompt, list):
@@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
196 205
197 return prompt, negative_prompt 206 return prompt, negative_prompt
198 207
199 def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): 208 def encode_prompt(
200 text_input_ids = self.prompt_processor.get_input_ids(prompt) 209 self,
210 prompt: Union[List[str], List[List[int]]],
211 negative_prompt: Union[List[str], List[List[int]]],
212 num_images_per_prompt: int,
213 do_classifier_free_guidance: bool,
214 device
215 ):
216 text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
201 text_input_ids *= num_images_per_prompt 217 text_input_ids *= num_images_per_prompt
202 218
203 if do_classifier_free_guidance: 219 if do_classifier_free_guidance:
204 unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) 220 unconditional_input_ids = self.prompt_processor.get_input_ids(
221 negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
205 unconditional_input_ids *= num_images_per_prompt 222 unconditional_input_ids *= num_images_per_prompt
206 text_input_ids = unconditional_input_ids + text_input_ids 223 text_input_ids = unconditional_input_ids + text_input_ids
207 224
@@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
314 @torch.no_grad() 331 @torch.no_grad()
315 def __call__( 332 def __call__(
316 self, 333 self,
317 prompt: Union[str, List[str], List[List[str]]], 334 prompt: Union[str, List[str], List[int], List[List[int]]],
318 negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, 335 negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
319 num_images_per_prompt: Optional[int] = 1, 336 num_images_per_prompt: Optional[int] = 1,
320 strength: float = 0.8, 337 strength: float = 0.8,
321 height: Optional[int] = 768, 338 height: Optional[int] = None,
322 width: Optional[int] = 768, 339 width: Optional[int] = None,
323 num_inference_steps: Optional[int] = 50, 340 num_inference_steps: Optional[int] = 50,
324 guidance_scale: Optional[float] = 7.5, 341 guidance_scale: Optional[float] = 7.5,
325 eta: Optional[float] = 0.0, 342 eta: Optional[float] = 0.0,
@@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
379 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 396 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
380 (nsfw) content, according to the `safety_checker`. 397 (nsfw) content, according to the `safety_checker`.
381 """ 398 """
399 # 0. Default height and width to unet
400 height = height or self.unet.config.sample_size * self.vae_scale_factor
401 width = width or self.unet.config.sample_size * self.vae_scale_factor
382 402
383 # 1. Check inputs. Raise error if not correct 403 # 1. Check inputs. Raise error if not correct
384 prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) 404 prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)