path: root/pipelines/stable_diffusion
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  96
1 file changed, 55 insertions(+), 41 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8b08a6f..b68b028 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,6 +24,22 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0
 
 
+def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
+    if isinstance(prompt, str):
+        prompt = [prompt] * batch_size
+
+    if isinstance(prompt, list) and isinstance(prompt[0], str):
+        prompt = [[p] for p in prompt]
+
+    if isinstance(prompt, list) and isinstance(prompt[0], list):
+        prompt_size = prompt_size or max([len(p) for p in prompt])
+        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
+    else:
+        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+    return prompt_size, prompt
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
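
The new `normalize_prompt` helper canonicalizes every accepted prompt shape into a right-padded list of sub-prompt lists. A quick sketch of its behavior (illustrative inputs, not part of the commit):

    # a bare string becomes one batch entry with a single sub-prompt
    normalize_prompt("a cat")
    # -> (1, [["a cat"]])

    # ragged nested lists are padded with "" to the longest entry
    normalize_prompt([["a cat", "oil painting"], ["a dog"]])
    # -> (2, [["a cat", "oil painting"], ["a dog", ""]])

    # negative prompts reuse the batch_size and prompt_size of the positive ones
    normalize_prompt("", batch_size=2, prompt_size=2)
    # -> (2, [["", ""], ["", ""]])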
@@ -85,11 +101,39 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def embeddings_for_prompt(self, prompt: List[List[str]]):
+        text_embeddings = []
+
+        for p in prompt:
+            inputs = self.tokenizer(
+                p,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
+            )
+            input_ids = inputs.input_ids
+
+            if input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                print(f"Too many tokens: {removed_text}")
+                input_ids = input_ids[:, : self.tokenizer.model_max_length]
+
+            embeddings = self.text_encoder(input_ids.to(self.device))[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt: Union[str, List[str], List[List[str]]],
+        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
         strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
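
`embeddings_for_prompt` encodes each batch entry's sub-prompts in one tokenizer/encoder pass and flattens them along the sequence axis. Assuming CLIP's usual 77-token context and 768-dim hidden size (the 768 is hard-coded in the reshape), the shapes for one entry work out as:

    # p = ["a cat", "oil painting"]
    # inputs.input_ids                 -> (2, 77)
    # self.text_encoder(...)[0]        -> (2, 77, 768)
    # .reshape((1, -1, 768))           -> (1, 154, 768)
    # torch.cat(text_embeddings)       -> (batch_size, 154, 768)

This is also why `normalize_prompt` pads every entry to a common `prompt_size`: the final `torch.cat` only succeeds if every batch entry flattens to the same sequence length.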
@@ -151,23 +195,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
 
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        elif isinstance(negative_prompt, list):
-            if len(negative_prompt) != batch_size:
-                raise ValueError(
-                    f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
-        else:
-            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+        prompt_size, prompt = normalize_prompt(prompt)
+        batch_size = len(prompt)
+        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
+
+        if len(negative_prompt) != batch_size:
+            raise ValueError(
+                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")
 
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
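
With the normalization factored out, `__call__` accepts all three prompt shapes uniformly, and a scalar negative prompt is broadcast across the batch. A hypothetical call, assuming an instantiated `pipe`:

    # two batch entries; the single negative prompt is expanded to both
    pipe(prompt=[["a cat", "oil painting"], ["a dog"]], negative_prompt="blurry")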
@@ -179,23 +213,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
 
         # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            print(f"Too many tokens: {removed_text}")
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.embeddings_for_prompt(prompt)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -203,11 +221,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)
 
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
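
Past the end of this hunk, the usual classifier-free-guidance recombination applies: the concatenated batch runs through the UNet once per step and the two halves are mixed with `guidance_scale`. A minimal sketch of that pattern, using standard diffusers variable names (not shown in this diff):

    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    ...
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)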