From 728dfcf57c30f40236b3a00d7380c4e0057cacb3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 17 Oct 2022 22:08:58 +0200 Subject: Implemented extended prompt limit --- .../stable_diffusion/vlpn_stable_diffusion.py | 96 +++++++++++++--------- 1 file changed, 55 insertions(+), 41 deletions(-) (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 8b08a6f..b68b028 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -24,6 +24,22 @@ def preprocess(image, w, h): return 2.0 * image - 1.0 +def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None): + if isinstance(prompt, str): + prompt = [prompt] * batch_size + + if isinstance(prompt, list) and isinstance(prompt[0], str): + prompt = [[p] for p in prompt] + + if isinstance(prompt, list) and isinstance(prompt[0], list): + prompt_size = prompt_size or max([len(p) for p in prompt]) + prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt] + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + return prompt_size, prompt + + class VlpnStableDiffusion(DiffusionPipeline): def __init__( self, @@ -85,11 +101,39 @@ class VlpnStableDiffusion(DiffusionPipeline): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def embeddings_for_prompt(self, prompt: List[List[str]]): + text_embeddings = [] + + for p in prompt: + inputs = self.tokenizer( + p, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + input_ids = inputs.input_ids + + if input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + print(f"Too many tokens: {removed_text}") + input_ids = input_ids[:, : self.tokenizer.model_max_length] + + embeddings = self.text_encoder(input_ids.to(self.device))[0] + embeddings = embeddings.reshape((1, -1, 768)) + text_embeddings.append(embeddings) + + text_embeddings = torch.cat(text_embeddings) + return text_embeddings + @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, + prompt: Union[str, List[str], List[List[str]]], + negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, strength: float = 0.8, height: Optional[int] = 512, width: Optional[int] = 512, @@ -151,23 +195,13 @@ class VlpnStableDiffusion(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - elif isinstance(negative_prompt, list): - if len(negative_prompt) != batch_size: - raise ValueError( - f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") - else: - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + prompt_size, prompt = normalize_prompt(prompt) + batch_size = len(prompt) + _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size) + + if len(negative_prompt) != batch_size: + raise ValueError( + f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -179,23 +213,7 @@ class VlpnStableDiffusion(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - print(f"Too many tokens: {removed_text}") - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.embeddings_for_prompt(prompt) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -203,11 +221,7 @@ class VlpnStableDiffusion(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.embeddings_for_prompt(negative_prompt) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch -- cgit v1.2.3-54-g00ecf