Diffstat (limited to 'infer.py')
-rw-r--r--   infer.py   32
1 file changed, 4 insertions(+), 28 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,6 @@ default_args = {
     "scheduler": "euler_a",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -78,10 +77,6 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--ag_embeddings_dir",
-        type=str,
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -211,24 +206,7 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
             print(f"Loaded {placeholder_token}")
 
 
-def load_embeddings_ag(pipeline, embeddings_dir):
-    print(f"Loading Aesthetic Gradient embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-
-            data = torch.load(file, map_location="cpu")
-
-            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
-
-            print(f"Loaded {placeholder_token}")
-
-
-def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
+def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
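Note: the deleted load_embeddings_ag() and the surviving load_embeddings_ti() share the same directory convention: every regular file in the embeddings directory is an embedding, its stem doubles as the placeholder token, and the payload is torch.load()-ed onto CPU first. A minimal sketch of that convention, using a hypothetical helper name (iter_embedding_files does not exist in infer.py):

    from pathlib import Path

    import torch

    def iter_embedding_files(embeddings_dir):
        # Hypothetical helper mirroring the loop infer.py inlines:
        # the file stem is the placeholder token, and tensors are
        # loaded onto CPU before being moved to the target device.
        embeddings_dir = Path(embeddings_dir)
        embeddings_dir.mkdir(parents=True, exist_ok=True)
        for file in embeddings_dir.iterdir():
            if file.is_file():
                yield file.stem, torch.load(file, map_location="cpu")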
@@ -262,13 +240,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtyp
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.aesthetic_gradient_iters = 30
+    pipeline.aesthetic_gradient_iters = 20
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
 
-    load_embeddings_ag(pipeline, ag_embeddings_dir)
-
     return pipeline
 
 
@@ -288,7 +264,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with torch.autocast("cuda"):
+    with torch.autocast("cuda"), torch.inference_mode():
        for i in range(args.batch_num):
            pipeline.set_progress_bar_config(
                desc=f"Batch {i + 1} of {args.batch_num}",
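Note: adding torch.inference_mode() on top of autocast disables autograd tracking for the whole generation loop. It is stricter than torch.no_grad() (tensors created inside can never participate in a later backward pass) and also skips autograd's view/version-counter bookkeeping, making it the cheaper choice for pure inference. A minimal sketch of the combined context, with a placeholder module standing in for the pipeline:

    import torch

    model = torch.nn.Linear(8, 8).to("cuda")  # placeholder for the pipeline
    batch = torch.randn(4, 8, device="cuda")

    # Mixed-precision autocasting plus no gradient bookkeeping; the
    # outputs are plain tensors unusable in any later backward pass.
    with torch.autocast("cuda"), torch.inference_mode():
        out = model(batch)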
@@ -366,7 +342,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
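Note: callers of create_pipeline() now pass four positional arguments instead of five. A hypothetical invocation matching the new signature (the model id is a placeholder; the other values follow the defaults shown in this diff):

    import torch

    # Model id is illustrative, not a default from infer.py.
    pipeline = create_pipeline(
        "runwayml/stable-diffusion-v1-5",  # model
        "euler_a",                         # scheduler
        "embeddings_ti",                   # ti_embeddings_dir
        torch.float32,                     # dtype for "fp32" precision
    )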