diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-31 17:12:12 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-31 17:12:12 +0100 |
| commit | b42e7fbc29fd8045a2b932eb8ae76587f51f7513 (patch) | |
| tree | 85321e605cd8e183a0b9e05efcc4282921e667e0 /infer.py | |
| parent | Simplified multi-vector embedding code (diff) | |
| download | textual-inversion-diff-b42e7fbc29fd8045a2b932eb8ae76587f51f7513.tar.gz textual-inversion-diff-b42e7fbc29fd8045a2b932eb8ae76587f51f7513.tar.bz2 textual-inversion-diff-b42e7fbc29fd8045a2b932eb8ae76587f51f7513.zip | |
Bugfixes for multi-vector token handling
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 13 |
1 files changed, 9 insertions, 4 deletions
| @@ -7,6 +7,8 @@ import cmd | |||
| 7 | from pathlib import Path | 7 | from pathlib import Path |
| 8 | import torch | 8 | import torch |
| 9 | import json | 9 | import json |
| 10 | import traceback | ||
| 11 | |||
| 10 | from PIL import Image | 12 | from PIL import Image |
| 11 | from slugify import slugify | 13 | from slugify import slugify |
| 12 | from diffusers import ( | 14 | from diffusers import ( |
| @@ -165,8 +167,8 @@ def run_parser(parser, defaults, input=None): | |||
| 165 | conf_args = argparse.Namespace() | 167 | conf_args = argparse.Namespace() |
| 166 | 168 | ||
| 167 | if args.config is not None: | 169 | if args.config is not None: |
| 168 | args = load_config(args.config) | 170 | conf_args = load_config(args.config) |
| 169 | args = parser.parse_args(namespace=argparse.Namespace(**args)) | 171 | conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0] |
| 170 | 172 | ||
| 171 | res = defaults.copy() | 173 | res = defaults.copy() |
| 172 | for dict in [vars(conf_args), vars(args)]: | 174 | for dict in [vars(conf_args), vars(args)]: |
| @@ -295,6 +297,7 @@ class CmdParse(cmd.Cmd): | |||
| 295 | elements = shlex.split(line) | 297 | elements = shlex.split(line) |
| 296 | except ValueError as e: | 298 | except ValueError as e: |
| 297 | print(str(e)) | 299 | print(str(e)) |
| 300 | return | ||
| 298 | 301 | ||
| 299 | if elements[0] == 'q': | 302 | if elements[0] == 'q': |
| 300 | return True | 303 | return True |
| @@ -306,9 +309,11 @@ class CmdParse(cmd.Cmd): | |||
| 306 | print('Try again with a prompt!') | 309 | print('Try again with a prompt!') |
| 307 | return | 310 | return |
| 308 | except SystemExit: | 311 | except SystemExit: |
| 312 | traceback.print_exc() | ||
| 309 | self.parser.print_help() | 313 | self.parser.print_help() |
| 314 | return | ||
| 310 | except Exception as e: | 315 | except Exception as e: |
| 311 | print(e) | 316 | traceback.print_exc() |
| 312 | return | 317 | return |
| 313 | 318 | ||
| 314 | try: | 319 | try: |
| @@ -316,7 +321,7 @@ class CmdParse(cmd.Cmd): | |||
| 316 | except KeyboardInterrupt: | 321 | except KeyboardInterrupt: |
| 317 | print('Generation cancelled.') | 322 | print('Generation cancelled.') |
| 318 | except Exception as e: | 323 | except Exception as e: |
| 319 | print(e) | 324 | traceback.print_exc() |
| 320 | return | 325 | return |
| 321 | 326 | ||
| 322 | def do_exit(self, line): | 327 | def do_exit(self, line): |
