From 306f2bfb620e6882737658bd3694c79365d75e4b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 18 Oct 2022 15:23:40 +0200
Subject: Improved prompt handling

---
 .../stable_diffusion/vlpn_stable_diffusion.py      | 72 +++++-----------------
 1 file changed, 17 insertions(+), 55 deletions(-)

(limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b68b028..3da0169 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler
+from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -24,22 +25,6 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0
 
 
-def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
-    if isinstance(prompt, str):
-        prompt = [prompt] * batch_size
-
-    if isinstance(prompt, list) and isinstance(prompt[0], str):
-        prompt = [[p] for p in prompt]
-
-    if isinstance(prompt, list) and isinstance(prompt[0], list):
-        prompt_size = prompt_size or max([len(p) for p in prompt])
-        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
-    else:
-        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-    return prompt_size, prompt
-
-
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def embeddings_for_prompt(self, prompt: List[List[str]]):
-        text_embeddings = []
-
-        for p in prompt:
-            inputs = self.tokenizer(
-                p,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            )
-            input_ids = inputs.input_ids
-
-            if input_ids.shape[-1] > self.tokenizer.model_max_length:
-                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-                print(f"Too many tokens: {removed_text}")
-                input_ids = input_ids[:, : self.tokenizer.model_max_length]
-
-            embeddings = self.text_encoder(input_ids.to(self.device))[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
     @torch.no_grad()
     def __call__(
         self,
@@ -195,13 +154,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
-        prompt_size, prompt = normalize_prompt(prompt)
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
         batch_size = len(prompt)
-        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
 
-        if len(negative_prompt) != batch_size:
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+
+        if len(negative_prompt) != len(prompt):
             raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
 
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
@@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
 
         # get prompt text embeddings
-        text_embeddings = self.embeddings_for_prompt(prompt)
+        text_input_ids = self.prompt_processor.get_input_ids(prompt)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -221,12 +184,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            text_input_ids = unconditional_input_ids + text_input_ids
 
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
 
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = num_inference_steps + offset
--
cgit v1.2.3-54-g00ecf