diff options
author | Volpeon <git@volpeon.ink> | 2022-10-18 15:23:40 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-18 15:23:40 +0200 |
commit | 306f2bfb620e6882737658bd3694c79365d75e4b (patch) | |
tree | 8b461c4360b9baa5758c2af0100348f14df8c76d /pipelines | |
parent | Implemented extended prompt limit (diff) | |
download | textual-inversion-diff-306f2bfb620e6882737658bd3694c79365d75e4b.tar.gz textual-inversion-diff-306f2bfb620e6882737658bd3694c79365d75e4b.tar.bz2 textual-inversion-diff-306f2bfb620e6882737658bd3694c79365d75e4b.zip |
Improved prompt handling
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 72 |
1 files changed, 17 insertions, 55 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index b68b028..3da0169 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict | |||
10 | from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel | 10 | from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel |
11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
12 | from diffusers.utils import logging | 12 | from diffusers.utils import logging |
13 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel | 13 | from transformers import CLIPTextModel, CLIPTokenizer |
14 | from schedulers.scheduling_euler_a import EulerAScheduler | 14 | from schedulers.scheduling_euler_a import EulerAScheduler |
15 | from models.clip.prompt import PromptProcessor | ||
15 | 16 | ||
16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
17 | 18 | ||
@@ -24,22 +25,6 @@ def preprocess(image, w, h): | |||
24 | return 2.0 * image - 1.0 | 25 | return 2.0 * image - 1.0 |
25 | 26 | ||
26 | 27 | ||
27 | def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None): | ||
28 | if isinstance(prompt, str): | ||
29 | prompt = [prompt] * batch_size | ||
30 | |||
31 | if isinstance(prompt, list) and isinstance(prompt[0], str): | ||
32 | prompt = [[p] for p in prompt] | ||
33 | |||
34 | if isinstance(prompt, list) and isinstance(prompt[0], list): | ||
35 | prompt_size = prompt_size or max([len(p) for p in prompt]) | ||
36 | prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt] | ||
37 | else: | ||
38 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
39 | |||
40 | return prompt_size, prompt | ||
41 | |||
42 | |||
43 | class VlpnStableDiffusion(DiffusionPipeline): | 28 | class VlpnStableDiffusion(DiffusionPipeline): |
44 | def __init__( | 29 | def __init__( |
45 | self, | 30 | self, |
@@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
66 | new_config["steps_offset"] = 1 | 51 | new_config["steps_offset"] = 1 |
67 | scheduler._internal_dict = FrozenDict(new_config) | 52 | scheduler._internal_dict = FrozenDict(new_config) |
68 | 53 | ||
54 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
55 | |||
69 | self.register_modules( | 56 | self.register_modules( |
70 | vae=vae, | 57 | vae=vae, |
71 | text_encoder=text_encoder, | 58 | text_encoder=text_encoder, |
@@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
101 | # set slice_size = `None` to disable `attention slicing` | 88 | # set slice_size = `None` to disable `attention slicing` |
102 | self.enable_attention_slicing(None) | 89 | self.enable_attention_slicing(None) |
103 | 90 | ||
104 | def embeddings_for_prompt(self, prompt: List[List[str]]): | ||
105 | text_embeddings = [] | ||
106 | |||
107 | for p in prompt: | ||
108 | inputs = self.tokenizer( | ||
109 | p, | ||
110 | padding="max_length", | ||
111 | max_length=self.tokenizer.model_max_length, | ||
112 | return_tensors="pt", | ||
113 | ) | ||
114 | input_ids = inputs.input_ids | ||
115 | |||
116 | if input_ids.shape[-1] > self.tokenizer.model_max_length: | ||
117 | removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:]) | ||
118 | logger.warning( | ||
119 | "The following part of your input was truncated because CLIP can only handle sequences up to" | ||
120 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" | ||
121 | ) | ||
122 | print(f"Too many tokens: {removed_text}") | ||
123 | input_ids = input_ids[:, : self.tokenizer.model_max_length] | ||
124 | |||
125 | embeddings = self.text_encoder(input_ids.to(self.device))[0] | ||
126 | embeddings = embeddings.reshape((1, -1, 768)) | ||
127 | text_embeddings.append(embeddings) | ||
128 | |||
129 | text_embeddings = torch.cat(text_embeddings) | ||
130 | return text_embeddings | ||
131 | |||
132 | @torch.no_grad() | 91 | @torch.no_grad() |
133 | def __call__( | 92 | def __call__( |
134 | self, | 93 | self, |
@@ -195,13 +154,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
195 | (nsfw) content, according to the `safety_checker`. | 154 | (nsfw) content, according to the `safety_checker`. |
196 | """ | 155 | """ |
197 | 156 | ||
198 | prompt_size, prompt = normalize_prompt(prompt) | 157 | if isinstance(prompt, str): |
158 | prompt = [prompt] | ||
159 | |||
199 | batch_size = len(prompt) | 160 | batch_size = len(prompt) |
200 | _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size) | ||
201 | 161 | ||
202 | if len(negative_prompt) != batch_size: | 162 | if isinstance(negative_prompt, str): |
163 | negative_prompt = [negative_prompt] * batch_size | ||
164 | |||
165 | if len(negative_prompt) != len(prompt): | ||
203 | raise ValueError( | 166 | raise ValueError( |
204 | f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}") | 167 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") |
205 | 168 | ||
206 | if height % 8 != 0 or width % 8 != 0: | 169 | if height % 8 != 0 or width % 8 != 0: |
207 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 170 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
@@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
213 | self.scheduler.set_timesteps(num_inference_steps) | 176 | self.scheduler.set_timesteps(num_inference_steps) |
214 | 177 | ||
215 | # get prompt text embeddings | 178 | # get prompt text embeddings |
216 | text_embeddings = self.embeddings_for_prompt(prompt) | 179 | text_input_ids = self.prompt_processor.get_input_ids(prompt) |
217 | 180 | ||
218 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | 181 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) |
219 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | 182 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` |
@@ -221,12 +184,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
221 | do_classifier_free_guidance = guidance_scale > 1.0 | 184 | do_classifier_free_guidance = guidance_scale > 1.0 |
222 | # get unconditional embeddings for classifier free guidance | 185 | # get unconditional embeddings for classifier free guidance |
223 | if do_classifier_free_guidance: | 186 | if do_classifier_free_guidance: |
224 | uncond_embeddings = self.embeddings_for_prompt(negative_prompt) | 187 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) |
188 | text_input_ids = unconditional_input_ids + text_input_ids | ||
225 | 189 | ||
226 | # For classifier free guidance, we need to do two forward passes. | 190 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) |
227 | # Here we concatenate the unconditional and text embeddings into a single batch | 191 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) |
228 | # to avoid doing two forward passes | ||
229 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | ||
230 | 192 | ||
231 | offset = self.scheduler.config.get("steps_offset", 0) | 193 | offset = self.scheduler.config.get("steps_offset", 0) |
232 | init_timestep = num_inference_steps + offset | 194 | init_timestep = num_inference_steps + offset |