summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-14 20:03:01 +0200
committerVolpeon <git@volpeon.ink>2022-10-14 20:03:01 +0200
commit6a49074dce78615bce54777fb2be3bfd0dd8f780 (patch)
tree0f7dde5ea6b6343fb6e0a527e5ebb2940d418dce /infer.py
parentAdded support for Aesthetic Gradients (diff)
downloadtextual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.gz
textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.bz2
textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.zip
Removed aesthetic gradients; training improvements
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py32
1 files changed, 4 insertions, 28 deletions
diff --git a/infer.py b/infer.py
index 650c119..1a0baf5 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,6 @@ default_args = {
24 "scheduler": "euler_a", 24 "scheduler": "euler_a",
25 "precision": "fp32", 25 "precision": "fp32",
26 "ti_embeddings_dir": "embeddings_ti", 26 "ti_embeddings_dir": "embeddings_ti",
27 "ag_embeddings_dir": "embeddings_ag",
28 "output_dir": "output/inference", 27 "output_dir": "output/inference",
29 "config": None, 28 "config": None,
30} 29}
@@ -78,10 +77,6 @@ def create_args_parser():
78 type=str, 77 type=str,
79 ) 78 )
80 parser.add_argument( 79 parser.add_argument(
81 "--ag_embeddings_dir",
82 type=str,
83 )
84 parser.add_argument(
85 "--output_dir", 80 "--output_dir",
86 type=str, 81 type=str,
87 ) 82 )
@@ -211,24 +206,7 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
211 print(f"Loaded {placeholder_token}") 206 print(f"Loaded {placeholder_token}")
212 207
213 208
214def load_embeddings_ag(pipeline, embeddings_dir): 209def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
215 print(f"Loading Aesthetic Gradient embeddings")
216
217 embeddings_dir = Path(embeddings_dir)
218 embeddings_dir.mkdir(parents=True, exist_ok=True)
219
220 for file in embeddings_dir.iterdir():
221 if file.is_file():
222 placeholder_token = file.stem
223
224 data = torch.load(file, map_location="cpu")
225
226 pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
227
228 print(f"Loaded {placeholder_token}")
229
230
231def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
232 print("Loading Stable Diffusion pipeline...") 210 print("Loading Stable Diffusion pipeline...")
233 211
234 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 212 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -262,13 +240,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtyp
262 tokenizer=tokenizer, 240 tokenizer=tokenizer,
263 scheduler=scheduler, 241 scheduler=scheduler,
264 ) 242 )
265 pipeline.aesthetic_gradient_iters = 30 243 pipeline.aesthetic_gradient_iters = 20
266 pipeline.to("cuda") 244 pipeline.to("cuda")
267 245
268 print("Pipeline loaded.") 246 print("Pipeline loaded.")
269 247
270 load_embeddings_ag(pipeline, ag_embeddings_dir)
271
272 return pipeline 248 return pipeline
273 249
274 250
@@ -288,7 +264,7 @@ def generate(output_dir, pipeline, args):
288 else: 264 else:
289 init_image = None 265 init_image = None
290 266
291 with torch.autocast("cuda"): 267 with torch.autocast("cuda"), torch.inference_mode():
292 for i in range(args.batch_num): 268 for i in range(args.batch_num):
293 pipeline.set_progress_bar_config( 269 pipeline.set_progress_bar_config(
294 desc=f"Batch {i + 1} of {args.batch_num}", 270 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -366,7 +342,7 @@ def main():
366 output_dir = Path(args.output_dir) 342 output_dir = Path(args.output_dir)
367 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 343 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
368 344
369 pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype) 345 pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
370 cmd_parser = create_cmd_parser() 346 cmd_parser = create_cmd_parser()
371 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) 347 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
372 cmd_prompt.cmdloop() 348 cmd_prompt.cmdloop()