summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-31 17:12:12 +0100
committerVolpeon <git@volpeon.ink>2022-12-31 17:12:12 +0100
commitb42e7fbc29fd8045a2b932eb8ae76587f51f7513 (patch)
tree85321e605cd8e183a0b9e05efcc4282921e667e0 /infer.py
parentSimplified multi-vector embedding code (diff)
downloadtextual-inversion-diff-b42e7fbc29fd8045a2b932eb8ae76587f51f7513.tar.gz
textual-inversion-diff-b42e7fbc29fd8045a2b932eb8ae76587f51f7513.tar.bz2
textual-inversion-diff-b42e7fbc29fd8045a2b932eb8ae76587f51f7513.zip
Bugfixes for multi-vector token handling
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/infer.py b/infer.py
index 4bcaff5..f88245a 100644
--- a/infer.py
+++ b/infer.py
@@ -7,6 +7,8 @@ import cmd
7from pathlib import Path 7from pathlib import Path
8import torch 8import torch
9import json 9import json
10import traceback
11
10from PIL import Image 12from PIL import Image
11from slugify import slugify 13from slugify import slugify
12from diffusers import ( 14from diffusers import (
@@ -165,8 +167,8 @@ def run_parser(parser, defaults, input=None):
165 conf_args = argparse.Namespace() 167 conf_args = argparse.Namespace()
166 168
167 if args.config is not None: 169 if args.config is not None:
168 args = load_config(args.config) 170 conf_args = load_config(args.config)
169 args = parser.parse_args(namespace=argparse.Namespace(**args)) 171 conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
170 172
171 res = defaults.copy() 173 res = defaults.copy()
172 for dict in [vars(conf_args), vars(args)]: 174 for dict in [vars(conf_args), vars(args)]:
@@ -295,6 +297,7 @@ class CmdParse(cmd.Cmd):
295 elements = shlex.split(line) 297 elements = shlex.split(line)
296 except ValueError as e: 298 except ValueError as e:
297 print(str(e)) 299 print(str(e))
300 return
298 301
299 if elements[0] == 'q': 302 if elements[0] == 'q':
300 return True 303 return True
@@ -306,9 +309,11 @@ class CmdParse(cmd.Cmd):
306 print('Try again with a prompt!') 309 print('Try again with a prompt!')
307 return 310 return
308 except SystemExit: 311 except SystemExit:
312 traceback.print_exc()
309 self.parser.print_help() 313 self.parser.print_help()
314 return
310 except Exception as e: 315 except Exception as e:
311 print(e) 316 traceback.print_exc()
312 return 317 return
313 318
314 try: 319 try:
@@ -316,7 +321,7 @@ class CmdParse(cmd.Cmd):
316 except KeyboardInterrupt: 321 except KeyboardInterrupt:
317 print('Generation cancelled.') 322 print('Generation cancelled.')
318 except Exception as e: 323 except Exception as e:
319 print(e) 324 traceback.print_exc()
320 return 325 return
321 326
322 def do_exit(self, line): 327 def do_exit(self, line):