diff options
author | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
commit | 728dfcf57c30f40236b3a00d7380c4e0057cacb3 (patch) | |
tree | 9aee7759b7f31752a87a1c9af4d9c4ea20f9a862 /pipelines | |
parent | Upstream updates; better handling of textual embedding (diff) | |
download | textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.gz textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.tar.bz2 textual-inversion-diff-728dfcf57c30f40236b3a00d7380c4e0057cacb3.zip |
Implemented extended prompt limit
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 96 |
1 files changed, 55 insertions, 41 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 8b08a6f..b68b028 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -24,6 +24,22 @@ def preprocess(image, w, h): | |||
24 | return 2.0 * image - 1.0 | 24 | return 2.0 * image - 1.0 |
25 | 25 | ||
26 | 26 | ||
27 | def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None): | ||
28 | if isinstance(prompt, str): | ||
29 | prompt = [prompt] * batch_size | ||
30 | |||
31 | if isinstance(prompt, list) and isinstance(prompt[0], str): | ||
32 | prompt = [[p] for p in prompt] | ||
33 | |||
34 | if isinstance(prompt, list) and isinstance(prompt[0], list): | ||
35 | prompt_size = prompt_size or max([len(p) for p in prompt]) | ||
36 | prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt] | ||
37 | else: | ||
38 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
39 | |||
40 | return prompt_size, prompt | ||
41 | |||
42 | |||
27 | class VlpnStableDiffusion(DiffusionPipeline): | 43 | class VlpnStableDiffusion(DiffusionPipeline): |
28 | def __init__( | 44 | def __init__( |
29 | self, | 45 | self, |
@@ -85,11 +101,39 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
85 | # set slice_size = `None` to disable `attention slicing` | 101 | # set slice_size = `None` to disable `attention slicing` |
86 | self.enable_attention_slicing(None) | 102 | self.enable_attention_slicing(None) |
87 | 103 | ||
104 | def embeddings_for_prompt(self, prompt: List[List[str]]): | ||
105 | text_embeddings = [] | ||
106 | |||
107 | for p in prompt: | ||
108 | inputs = self.tokenizer( | ||
109 | p, | ||
110 | padding="max_length", | ||
111 | max_length=self.tokenizer.model_max_length, | ||
112 | return_tensors="pt", | ||
113 | ) | ||
114 | input_ids = inputs.input_ids | ||
115 | |||
116 | if input_ids.shape[-1] > self.tokenizer.model_max_length: | ||
117 | removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:]) | ||
118 | logger.warning( | ||
119 | "The following part of your input was truncated because CLIP can only handle sequences up to" | ||
120 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" | ||
121 | ) | ||
122 | print(f"Too many tokens: {removed_text}") | ||
123 | input_ids = input_ids[:, : self.tokenizer.model_max_length] | ||
124 | |||
125 | embeddings = self.text_encoder(input_ids.to(self.device))[0] | ||
126 | embeddings = embeddings.reshape((1, -1, 768)) | ||
127 | text_embeddings.append(embeddings) | ||
128 | |||
129 | text_embeddings = torch.cat(text_embeddings) | ||
130 | return text_embeddings | ||
131 | |||
88 | @torch.no_grad() | 132 | @torch.no_grad() |
89 | def __call__( | 133 | def __call__( |
90 | self, | 134 | self, |
91 | prompt: Union[str, List[str]], | 135 | prompt: Union[str, List[str], List[List[str]]], |
92 | negative_prompt: Optional[Union[str, List[str]]] = None, | 136 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, |
93 | strength: float = 0.8, | 137 | strength: float = 0.8, |
94 | height: Optional[int] = 512, | 138 | height: Optional[int] = 512, |
95 | width: Optional[int] = 512, | 139 | width: Optional[int] = 512, |
@@ -151,23 +195,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
151 | (nsfw) content, according to the `safety_checker`. | 195 | (nsfw) content, according to the `safety_checker`. |
152 | """ | 196 | """ |
153 | 197 | ||
154 | if isinstance(prompt, str): | 198 | prompt_size, prompt = normalize_prompt(prompt) |
155 | batch_size = 1 | 199 | batch_size = len(prompt) |
156 | elif isinstance(prompt, list): | 200 | _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size) |
157 | batch_size = len(prompt) | 201 | |
158 | else: | 202 | if len(negative_prompt) != batch_size: |
159 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | 203 | raise ValueError( |
160 | 204 | f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}") | |
161 | if negative_prompt is None: | ||
162 | negative_prompt = [""] * batch_size | ||
163 | elif isinstance(negative_prompt, str): | ||
164 | negative_prompt = [negative_prompt] * batch_size | ||
165 | elif isinstance(negative_prompt, list): | ||
166 | if len(negative_prompt) != batch_size: | ||
167 | raise ValueError( | ||
168 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
169 | else: | ||
170 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | ||
171 | 205 | ||
172 | if height % 8 != 0 or width % 8 != 0: | 206 | if height % 8 != 0 or width % 8 != 0: |
173 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 207 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
@@ -179,23 +213,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
179 | self.scheduler.set_timesteps(num_inference_steps) | 213 | self.scheduler.set_timesteps(num_inference_steps) |
180 | 214 | ||
181 | # get prompt text embeddings | 215 | # get prompt text embeddings |
182 | text_inputs = self.tokenizer( | 216 | text_embeddings = self.embeddings_for_prompt(prompt) |
183 | prompt, | ||
184 | padding="max_length", | ||
185 | max_length=self.tokenizer.model_max_length, | ||
186 | return_tensors="pt", | ||
187 | ) | ||
188 | text_input_ids = text_inputs.input_ids | ||
189 | |||
190 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: | ||
191 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) | ||
192 | logger.warning( | ||
193 | "The following part of your input was truncated because CLIP can only handle sequences up to" | ||
194 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" | ||
195 | ) | ||
196 | print(f"Too many tokens: {removed_text}") | ||
197 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] | ||
198 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] | ||
199 | 217 | ||
200 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | 218 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) |
201 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | 219 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` |
@@ -203,11 +221,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
203 | do_classifier_free_guidance = guidance_scale > 1.0 | 221 | do_classifier_free_guidance = guidance_scale > 1.0 |
204 | # get unconditional embeddings for classifier free guidance | 222 | # get unconditional embeddings for classifier free guidance |
205 | if do_classifier_free_guidance: | 223 | if do_classifier_free_guidance: |
206 | max_length = text_input_ids.shape[-1] | 224 | uncond_embeddings = self.embeddings_for_prompt(negative_prompt) |
207 | uncond_input = self.tokenizer( | ||
208 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" | ||
209 | ) | ||
210 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] | ||
211 | 225 | ||
212 | # For classifier free guidance, we need to do two forward passes. | 226 | # For classifier free guidance, we need to do two forward passes. |
213 | # Here we concatenate the unconditional and text embeddings into a single batch | 227 | # Here we concatenate the unconditional and text embeddings into a single batch |