From 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 13:38:43 +0100 Subject: Fixed aspect ratio bucketing; allow passing token IDs to pipeline --- .../stable_diffusion/vlpn_stable_diffusion.py | 40 ++++++++++++++++------ 1 file changed, 30 insertions(+), 10 deletions(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 53b5eea..cb300d1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline): unet=unet, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline): return torch.device(module._hf_hook.execution_device) return self.device - def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps): - if isinstance(prompt, str): + def check_inputs( + self, + prompt: Union[str, List[str], List[int], List[List[int]]], + negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]], + width: Optional[int], + height: Optional[int], + strength: float, + callback_steps: Optional[int] + ): + if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): prompt = [prompt] if negative_prompt is None: negative_prompt = "" - if isinstance(negative_prompt, str): + if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): negative_prompt = [negative_prompt] * len(prompt) if not isinstance(prompt, list): @@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline): return prompt, negative_prompt - def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): - text_input_ids = self.prompt_processor.get_input_ids(prompt) + def encode_prompt( + self, + prompt: Union[List[str], List[List[int]]], + negative_prompt: Union[List[str], List[List[int]]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + device + ): + text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt text_input_ids *= num_images_per_prompt if do_classifier_free_guidance: - unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) + unconditional_input_ids = self.prompt_processor.get_input_ids( + negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt unconditional_input_ids *= num_images_per_prompt text_input_ids = unconditional_input_ids + text_input_ids @@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str], List[List[str]]], - negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, + prompt: Union[str, List[str], List[int], List[List[int]]], + negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, num_images_per_prompt: Optional[int] = 1, strength: float = 0.8, - height: Optional[int] = 768, - width: Optional[int] = 768, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, @@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) -- cgit v1.2.3-70-g09d2