diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 18 |
1 file changed, 13 insertions, 5 deletions
@@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler | |||
19 | default_args = { | 19 | default_args = { |
20 | "model": None, | 20 | "model": None, |
21 | "scheduler": "euler_a", | 21 | "scheduler": "euler_a", |
22 | "precision": "bf16", | ||
22 | "output_dir": "output/inference", | 23 | "output_dir": "output/inference", |
23 | "config": None, | 24 | "config": None, |
24 | } | 25 | } |
@@ -28,7 +29,7 @@ default_cmds = { | |||
28 | "prompt": None, | 29 | "prompt": None, |
29 | "negative_prompt": None, | 30 | "negative_prompt": None, |
30 | "image": None, | 31 | "image": None, |
31 | "image_strength": .3, | 32 | "image_noise": .7, |
32 | "width": 512, | 33 | "width": 512, |
33 | "height": 512, | 34 | "height": 512, |
34 | "batch_size": 1, | 35 | "batch_size": 1, |
@@ -63,6 +64,11 @@ def create_args_parser(): | |||
63 | choices=["plms", "ddim", "klms", "euler_a"], | 64 | choices=["plms", "ddim", "klms", "euler_a"], |
64 | ) | 65 | ) |
65 | parser.add_argument( | 66 | parser.add_argument( |
67 | "--precision", | ||
68 | type=str, | ||
69 | choices=["fp32", "fp16", "bf16"], | ||
70 | ) | ||
71 | parser.add_argument( | ||
66 | "--output_dir", | 72 | "--output_dir", |
67 | type=str, | 73 | type=str, |
68 | ) | 74 | ) |
@@ -91,7 +97,7 @@ def create_cmd_parser(): | |||
91 | type=str, | 97 | type=str, |
92 | ) | 98 | ) |
93 | parser.add_argument( | 99 | parser.add_argument( |
94 | "--image_strength", | 100 | "--image_noise", |
95 | type=float, | 101 | type=float, |
96 | ) | 102 | ) |
97 | parser.add_argument( | 103 | parser.add_argument( |
@@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}): | |||
153 | json.dump(info, f, indent=4) | 159 | json.dump(info, f, indent=4) |
154 | 160 | ||
155 | 161 | ||
156 | def create_pipeline(model, scheduler, dtype=torch.bfloat16): | 162 | def create_pipeline(model, scheduler, dtype): |
157 | print("Loading Stable Diffusion pipeline...") | 163 | print("Loading Stable Diffusion pipeline...") |
158 | 164 | ||
159 | tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) | 165 | tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) |
@@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args): | |||
225 | guidance_scale=args.guidance_scale, | 231 | guidance_scale=args.guidance_scale, |
226 | generator=generator, | 232 | generator=generator, |
227 | latents=init_image, | 233 | latents=init_image, |
228 | strength=args.image_strength, | 234 | strength=args.image_noise, |
229 | ).images | 235 | ).images |
230 | 236 | ||
231 | for j, image in enumerate(images): | 237 | for j, image in enumerate(images): |
@@ -279,9 +285,11 @@ def main(): | |||
279 | 285 | ||
280 | args_parser = create_args_parser() | 286 | args_parser = create_args_parser() |
281 | args = run_parser(args_parser, default_args) | 287 | args = run_parser(args_parser, default_args) |
288 | |||
282 | output_dir = Path(args.output_dir) | 289 | output_dir = Path(args.output_dir) |
290 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | ||
283 | 291 | ||
284 | pipeline = create_pipeline(args.model, args.scheduler) | 292 | pipeline = create_pipeline(args.model, args.scheduler, dtype) |
285 | cmd_parser = create_cmd_parser() | 293 | cmd_parser = create_cmd_parser() |
286 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | 294 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) |
287 | cmd_prompt.cmdloop() | 295 | cmd_prompt.cmdloop() |