summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/infer.py b/infer.py
index 2b07b21..42b4e2d 100644
--- a/infer.py
+++ b/infer.py
@@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True
39default_args = { 39default_args = {
40 "model": "stabilityai/stable-diffusion-2-1", 40 "model": "stabilityai/stable-diffusion-2-1",
41 "precision": "fp32", 41 "precision": "fp32",
42 "ti_embeddings_dir": "embeddings", 42 "ti_embeddings_dir": "embeddings_ti",
43 "lora_embeddings_dir": "embeddings_lora",
43 "output_dir": "output/inference", 44 "output_dir": "output/inference",
44 "config": None, 45 "config": None,
45} 46}
@@ -60,6 +61,7 @@ default_cmds = {
60 "batch_num": 1, 61 "batch_num": 1,
61 "steps": 30, 62 "steps": 30,
62 "guidance_scale": 7.0, 63 "guidance_scale": 7.0,
64 "lora_scale": 0.5,
63 "seed": None, 65 "seed": None,
64 "config": None, 66 "config": None,
65} 67}
@@ -92,6 +94,10 @@ def create_args_parser():
92 type=str, 94 type=str,
93 ) 95 )
94 parser.add_argument( 96 parser.add_argument(
97 "--lora_embeddings_dir",
98 type=str,
99 )
100 parser.add_argument(
95 "--output_dir", 101 "--output_dir",
96 type=str, 102 type=str,
97 ) 103 )
@@ -169,6 +175,10 @@ def create_cmd_parser():
169 type=float, 175 type=float,
170 ) 176 )
171 parser.add_argument( 177 parser.add_argument(
178 "--lora_scale",
179 type=float,
180 )
181 parser.add_argument(
172 "--seed", 182 "--seed",
173 type=int, 183 type=int,
174 ) 184 )
@@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args):
315 generator=generator, 325 generator=generator,
316 image=init_image, 326 image=init_image,
317 strength=args.image_noise, 327 strength=args.image_noise,
328 cross_attention_kwargs={"scale": args.lora_scale},
318 ).images 329 ).images
319 330
320 for j, image in enumerate(images): 331 for j, image in enumerate(images):
@@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd):
334 prompt = 'dream> ' 345 prompt = 'dream> '
335 commands = [] 346 commands = []
336 347
337 def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): 348 def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser):
338 super().__init__() 349 super().__init__()
339 350
340 self.output_dir = output_dir 351 self.output_dir = output_dir
341 self.ti_embeddings_dir = ti_embeddings_dir 352 self.ti_embeddings_dir = ti_embeddings_dir
353 self.lora_embeddings_dir = lora_embeddings_dir
342 self.pipeline = pipeline 354 self.pipeline = pipeline
343 self.parser = parser 355 self.parser = parser
344 356
@@ -394,9 +406,12 @@ def main():
394 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 406 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
395 407
396 pipeline = create_pipeline(args.model, dtype) 408 pipeline = create_pipeline(args.model, dtype)
409
397 load_embeddings(pipeline, args.ti_embeddings_dir) 410 load_embeddings(pipeline, args.ti_embeddings_dir)
411 pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
412
398 cmd_parser = create_cmd_parser() 413 cmd_parser = create_cmd_parser()
399 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) 414 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
400 cmd_prompt.cmdloop() 415 cmd_prompt.cmdloop()
401 416
402 417