diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-02 20:57:43 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-02 20:57:43 +0200 |
| commit | 64c594869135354a38353551bd58a93e15bd5b85 (patch) | |
| tree | 2bcc085a396824f78e58c90b1f6e9553c7f5c8c1 /infer.py | |
| parent | Fix img2img (diff) | |
| download | textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.gz textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.tar.bz2 textual-inversion-diff-64c594869135354a38353551bd58a93e15bd5b85.zip | |
Small performance improvements
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 18 |
1 files changed, 13 insertions, 5 deletions
| @@ -19,6 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler | |||
| 19 | default_args = { | 19 | default_args = { |
| 20 | "model": None, | 20 | "model": None, |
| 21 | "scheduler": "euler_a", | 21 | "scheduler": "euler_a", |
| 22 | "precision": "bf16", | ||
| 22 | "output_dir": "output/inference", | 23 | "output_dir": "output/inference", |
| 23 | "config": None, | 24 | "config": None, |
| 24 | } | 25 | } |
| @@ -28,7 +29,7 @@ default_cmds = { | |||
| 28 | "prompt": None, | 29 | "prompt": None, |
| 29 | "negative_prompt": None, | 30 | "negative_prompt": None, |
| 30 | "image": None, | 31 | "image": None, |
| 31 | "image_strength": .3, | 32 | "image_noise": .7, |
| 32 | "width": 512, | 33 | "width": 512, |
| 33 | "height": 512, | 34 | "height": 512, |
| 34 | "batch_size": 1, | 35 | "batch_size": 1, |
| @@ -63,6 +64,11 @@ def create_args_parser(): | |||
| 63 | choices=["plms", "ddim", "klms", "euler_a"], | 64 | choices=["plms", "ddim", "klms", "euler_a"], |
| 64 | ) | 65 | ) |
| 65 | parser.add_argument( | 66 | parser.add_argument( |
| 67 | "--precision", | ||
| 68 | type=str, | ||
| 69 | choices=["fp32", "fp16", "bf16"], | ||
| 70 | ) | ||
| 71 | parser.add_argument( | ||
| 66 | "--output_dir", | 72 | "--output_dir", |
| 67 | type=str, | 73 | type=str, |
| 68 | ) | 74 | ) |
| @@ -91,7 +97,7 @@ def create_cmd_parser(): | |||
| 91 | type=str, | 97 | type=str, |
| 92 | ) | 98 | ) |
| 93 | parser.add_argument( | 99 | parser.add_argument( |
| 94 | "--image_strength", | 100 | "--image_noise", |
| 95 | type=float, | 101 | type=float, |
| 96 | ) | 102 | ) |
| 97 | parser.add_argument( | 103 | parser.add_argument( |
| @@ -153,7 +159,7 @@ def save_args(basepath, args, extra={}): | |||
| 153 | json.dump(info, f, indent=4) | 159 | json.dump(info, f, indent=4) |
| 154 | 160 | ||
| 155 | 161 | ||
| 156 | def create_pipeline(model, scheduler, dtype=torch.bfloat16): | 162 | def create_pipeline(model, scheduler, dtype): |
| 157 | print("Loading Stable Diffusion pipeline...") | 163 | print("Loading Stable Diffusion pipeline...") |
| 158 | 164 | ||
| 159 | tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) | 165 | tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) |
| @@ -225,7 +231,7 @@ def generate(output_dir, pipeline, args): | |||
| 225 | guidance_scale=args.guidance_scale, | 231 | guidance_scale=args.guidance_scale, |
| 226 | generator=generator, | 232 | generator=generator, |
| 227 | latents=init_image, | 233 | latents=init_image, |
| 228 | strength=args.image_strength, | 234 | strength=args.image_noise, |
| 229 | ).images | 235 | ).images |
| 230 | 236 | ||
| 231 | for j, image in enumerate(images): | 237 | for j, image in enumerate(images): |
| @@ -279,9 +285,11 @@ def main(): | |||
| 279 | 285 | ||
| 280 | args_parser = create_args_parser() | 286 | args_parser = create_args_parser() |
| 281 | args = run_parser(args_parser, default_args) | 287 | args = run_parser(args_parser, default_args) |
| 288 | |||
| 282 | output_dir = Path(args.output_dir) | 289 | output_dir = Path(args.output_dir) |
| 290 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | ||
| 283 | 291 | ||
| 284 | pipeline = create_pipeline(args.model, args.scheduler) | 292 | pipeline = create_pipeline(args.model, args.scheduler, dtype) |
| 285 | cmd_parser = create_cmd_parser() | 293 | cmd_parser = create_cmd_parser() |
| 286 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | 294 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) |
| 287 | cmd_prompt.cmdloop() | 295 | cmd_prompt.cmdloop() |
