author    Volpeon <git@volpeon.ink>  2022-10-18 15:23:40 +0200
committer Volpeon <git@volpeon.ink>  2022-10-18 15:23:40 +0200
commit    306f2bfb620e6882737658bd3694c79365d75e4b (patch)
tree      8b461c4360b9baa5758c2af0100348f14df8c76d /pipelines
parent    Implemented extended prompt limit (diff)
Improved prompt handling
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  72
1 file changed, 17 insertions(+), 55 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b68b028..3da0169 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler
+from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -24,22 +25,6 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0
 
 
-def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
-    if isinstance(prompt, str):
-        prompt = [prompt] * batch_size
-
-    if isinstance(prompt, list) and isinstance(prompt[0], str):
-        prompt = [[p] for p in prompt]
-
-    if isinstance(prompt, list) and isinstance(prompt[0], list):
-        prompt_size = prompt_size or max([len(p) for p in prompt])
-        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
-    else:
-        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-    return prompt_size, prompt
-
-
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def embeddings_for_prompt(self, prompt: List[List[str]]):
-        text_embeddings = []
-
-        for p in prompt:
-            inputs = self.tokenizer(
-                p,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            )
-            input_ids = inputs.input_ids
-
-            if input_ids.shape[-1] > self.tokenizer.model_max_length:
-                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-                print(f"Too many tokens: {removed_text}")
-                input_ids = input_ids[:, : self.tokenizer.model_max_length]
-
-            embeddings = self.text_encoder(input_ids.to(self.device))[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
     @torch.no_grad()
     def __call__(
         self,
@@ -195,13 +154,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
 
-        prompt_size, prompt = normalize_prompt(prompt)
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
         batch_size = len(prompt)
-        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
 
-        if len(negative_prompt) != batch_size:
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+
+        if len(negative_prompt) != len(prompt):
             raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
 
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
 
         # get prompt text embeddings
-        text_embeddings = self.embeddings_for_prompt(prompt)
+        text_input_ids = self.prompt_processor.get_input_ids(prompt)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -221,12 +184,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            text_input_ids = unconditional_input_ids + text_input_ids
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
 
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = num_inference_steps + offset
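
The PromptProcessor the pipeline now delegates to lives in models/clip/prompt.py and is not part of this diff. The sketch below is only a rough orientation: the constructor arguments and the method names (get_input_ids, unify_input_ids, get_embeddings) are taken from the calls above, while the chunked-encoding logic is an assumption based on the parent commit "Implemented extended prompt limit".

from typing import List, Union

import torch
from transformers import CLIPTextModel, CLIPTokenizer


class PromptProcessor:
    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def get_input_ids(self, prompt: Union[str, List[str]]) -> List[List[int]]:
        # Tokenize without truncation, so a single prompt may exceed model_max_length.
        if isinstance(prompt, str):
            prompt = [prompt]
        return self.tokenizer(prompt, padding=False).input_ids

    def unify_input_ids(self, input_ids: List[List[int]]) -> torch.Tensor:
        # Pad every prompt in the batch to a common width that is a multiple of
        # model_max_length, so the batch can later be split into equal windows.
        window = self.tokenizer.model_max_length
        longest = max(len(ids) for ids in input_ids)
        target = ((longest + window - 1) // window) * window
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        padded = [ids + [pad_id] * (target - len(ids)) for ids in input_ids]
        return torch.tensor(padded, dtype=torch.long)

    def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Encode each model_max_length window separately and concatenate the hidden
        # states along the sequence axis: (batch, n_windows * window, hidden_size).
        window = self.tokenizer.model_max_length
        device = self.text_encoder.device
        embeddings = [self.text_encoder(chunk.to(device))[0] for chunk in input_ids.split(window, dim=1)]
        return torch.cat(embeddings, dim=1)

If the helper works roughly like this, encoding each 77-token window separately and concatenating the hidden states is what lets the new __call__ path accept prompts longer than CLIP's context limit, instead of truncating them with a warning as the removed embeddings_for_prompt method did; it also explains why the unconditional and conditional prompts are merged as plain token-id lists before a single unify/embed pass.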