summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-02 20:57:43 +0200
committerVolpeon <git@volpeon.ink>2022-10-02 20:57:43 +0200
commit64c594869135354a38353551bd58a93e15bd5b85 (patch)
tree2bcc085a396824f78e58c90b1f6e9553c7f5c8c1 /infer.py
parentFix img2img (diff)
downloadtextual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.gz
textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.bz2
textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.zip
Small performance improvements
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/infer.py b/infer.py
index f2c380f..b15b17f 100644
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
19default_args = { 19default_args = {
20 "model": None, 20 "model": None,
21 "scheduler": "euler_a", 21 "scheduler": "euler_a",
22 "precision": "bf16",
22 "output_dir": "output/inference", 23 "output_dir": "output/inference",
23 "config": None, 24 "config": None,
24} 25}
@@ -28,7 +29,7 @@ default_cmds = {
28 "prompt": None, 29 "prompt": None,
29 "negative_prompt": None, 30 "negative_prompt": None,
30 "image": None, 31 "image": None,
31 "image_strength": .3, 32 "image_noise": .7,
32 "width": 512, 33 "width": 512,
33 "height": 512, 34 "height": 512,
34 "batch_size": 1, 35 "batch_size": 1,
@@ -63,6 +64,11 @@ def create_args_parser():
63 choices=["plms", "ddim", "klms", "euler_a"], 64 choices=["plms", "ddim", "klms", "euler_a"],
64 ) 65 )
65 parser.add_argument( 66 parser.add_argument(
67 "--precision",
68 type=str,
69 choices=["fp32", "fp16", "bf16"],
70 )
71 parser.add_argument(
66 "--output_dir", 72 "--output_dir",
67 type=str, 73 type=str,
68 ) 74 )
@@ -91,7 +97,7 @@ def create_cmd_parser():
91 type=str, 97 type=str,
92 ) 98 )
93 parser.add_argument( 99 parser.add_argument(
94 "--image_strength", 100 "--image_noise",
95 type=float, 101 type=float,
96 ) 102 )
97 parser.add_argument( 103 parser.add_argument(
@@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}):
153 json.dump(info, f, indent=4) 159 json.dump(info, f, indent=4)
154 160
155 161
156def create_pipeline(model, scheduler, dtype=torch.bfloat16): 162def create_pipeline(model, scheduler, dtype):
157 print("Loading Stable Diffusion pipeline...") 163 print("Loading Stable Diffusion pipeline...")
158 164
159 tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) 165 tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
@@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args):
225 guidance_scale=args.guidance_scale, 231 guidance_scale=args.guidance_scale,
226 generator=generator, 232 generator=generator,
227 latents=init_image, 233 latents=init_image,
228 strength=args.image_strength, 234 strength=args.image_noise,
229 ).images 235 ).images
230 236
231 for j, image in enumerate(images): 237 for j, image in enumerate(images):
@@ -279,9 +285,11 @@ def main():
279 285
280 args_parser = create_args_parser() 286 args_parser = create_args_parser()
281 args = run_parser(args_parser, default_args) 287 args = run_parser(args_parser, default_args)
288
282 output_dir = Path(args.output_dir) 289 output_dir = Path(args.output_dir)
290 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
283 291
284 pipeline = create_pipeline(args.model, args.scheduler) 292 pipeline = create_pipeline(args.model, args.scheduler, dtype)
285 cmd_parser = create_cmd_parser() 293 cmd_parser = create_cmd_parser()
286 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) 294 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
287 cmd_prompt.cmdloop() 295 cmd_prompt.cmdloop()