Diffstat (limited to 'infer.py')
-rw-r--r--   infer.py   75
1 file changed, 52 insertions, 23 deletions
diff --git a/infer.py b/infer.py
index 63b16d8..650c119 100644
--- a/infer.py
+++ b/infer.py
@@ -23,7 +23,8 @@ default_args = {
     "model": None,
     "scheduler": "euler_a",
     "precision": "fp32",
-    "embeddings_dir": "embeddings",
+    "ti_embeddings_dir": "embeddings_ti",
+    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -73,7 +74,11 @@ def create_args_parser():
         choices=["fp32", "fp16", "bf16"],
     )
     parser.add_argument(
-        "--embeddings_dir",
+        "--ti_embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
+        "--ag_embeddings_dir",
         type=str,
     )
     parser.add_argument(
@@ -167,42 +172,63 @@ def save_args(basepath, args, extra={}):
     json.dump(info, f, indent=4)
 
 
-def load_embeddings(tokenizer, text_encoder, embeddings_dir):
+def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
+    print(f"Loading Textual Inversion embeddings")
+
     embeddings_dir = Path(embeddings_dir)
     embeddings_dir.mkdir(parents=True, exist_ok=True)
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
-
-        num_added_tokens = tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
+        if file.is_file():
+            placeholder_token = file.stem
+
+            num_added_tokens = tokenizer.add_tokens(placeholder_token)
+            if num_added_tokens == 0:
+                raise ValueError(
+                    f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                    " `placeholder_token` that is not already in the tokenizer."
+                )
 
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-        data = torch.load(file, map_location="cpu")
-
-        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-        emb = next(iter(data.values()))
-        if len(emb.shape) == 1:
-            emb = emb.unsqueeze(0)
-
-        token_embeds[placeholder_token_id] = emb
-
-        print(f"Loaded embedding: {placeholder_token}")
+        if file.is_file():
+            placeholder_token = file.stem
+            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+
+            data = torch.load(file, map_location="cpu")
+
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[placeholder_token_id] = emb
+
+            print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, embeddings_dir, dtype):
+def load_embeddings_ag(pipeline, embeddings_dir):
+    print(f"Loading Aesthetic Gradient embeddings")
+
+    embeddings_dir = Path(embeddings_dir)
+    embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+    for file in embeddings_dir.iterdir():
+        if file.is_file():
+            placeholder_token = file.stem
+
+            data = torch.load(file, map_location="cpu")
+
+            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
+
+            print(f"Loaded {placeholder_token}")
+
+
+def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
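Both loaders key each embedding off its file stem and read a single object per file via torch.load. As a hedged illustration (the 768-dim width, the dict key, and the .bin extension are assumptions, not taken from this patch), a file that load_embeddings_ti() would accept could be produced like this:

from pathlib import Path

import torch

# The placeholder token comes from the file stem ("my_style"), not from the
# dict key; the loader only checks that the dict has exactly one entry.
Path("embeddings_ti").mkdir(parents=True, exist_ok=True)
emb = torch.randn(768)  # assumed text-encoder embedding width
torch.save({"emb_params": emb}, "embeddings_ti/my_style.bin")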
@@ -210,7 +236,7 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
 
-    load_embeddings(tokenizer, text_encoder, embeddings_dir)
+    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -236,10 +262,13 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
+    pipeline.aesthetic_gradient_iters = 30
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
 
+    load_embeddings_ag(pipeline, ag_embeddings_dir)
+
     return pipeline
 
 
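Putting the pieces together, the new create_pipeline signature can be called directly; a hedged usage sketch (the model path is a placeholder, and a CUDA device is required because create_pipeline moves the pipeline to "cuda"):

import torch

from infer import create_pipeline  # patched signature shown above

pipeline = create_pipeline(
    model="models/stable-diffusion-v1-5",  # placeholder path, not from the patch
    scheduler="euler_a",
    ti_embeddings_dir="embeddings_ti",
    ag_embeddings_dir="embeddings_ag",
    dtype=torch.float32,
)
# The patch hard-codes 30 aesthetic-gradient iterations; the attribute can be
# overridden on the returned pipeline if a different trade-off is wanted.
pipeline.aesthetic_gradient_iters = 10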
@@ -259,7 +288,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with torch.autocast("cuda"), torch.inference_mode():
+    with torch.autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
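Dropping torch.inference_mode() while keeping autocast re-enables autograd during generation, presumably so the aesthetic-gradient step can backpropagate through the text encoder (an inference of intent, not stated in the patch). A minimal CPU-only illustration of the difference:

import torch

w = torch.ones(3, requires_grad=True)

# Inside inference_mode no autograd graph is recorded, so nothing computed
# here can be backpropagated later.
with torch.inference_mode():
    y = (w * 2).sum()
print(y.requires_grad)  # False

# Outside inference_mode the graph is kept and gradients flow as usual.
z = (w * 2).sum()
z.backward()
print(w.grad)  # tensor([2., 2., 2.])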
@@ -337,7 +366,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()