summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-26 14:24:21 +0100
committerVolpeon <git@volpeon.ink>2022-12-26 14:24:21 +0100
commite0b686b475885f0c8480f7173eaa7359adf17e27 (patch)
tree6ad882f152e63801d31230466e4d6468e7ada697 /infer.py
parentCode simplifications, avoid autocast (diff)
downloadtextual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz
textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2
textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/infer.py b/infer.py
index f566114..ae0b4da 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 24from slugify import slugify
25 25
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from common import load_text_embeddings 27from common import load_text_embeddings, load_config
28 28
29 29
30torch.backends.cuda.matmul.allow_tf32 = True 30torch.backends.cuda.matmul.allow_tf32 = True
@@ -46,8 +46,8 @@ default_cmds = {
46 "negative_prompt": None, 46 "negative_prompt": None,
47 "image": None, 47 "image": None,
48 "image_noise": .7, 48 "image_noise": .7,
49 "width": 512, 49 "width": 768,
50 "height": 512, 50 "height": 768,
51 "batch_size": 1, 51 "batch_size": 1,
52 "batch_num": 1, 52 "batch_num": 1,
53 "steps": 30, 53 "steps": 30,
@@ -163,9 +163,8 @@ def run_parser(parser, defaults, input=None):
163 conf_args = argparse.Namespace() 163 conf_args = argparse.Namespace()
164 164
165 if args.config is not None: 165 if args.config is not None:
166 with open(args.config, 'rt') as f: 166 args = load_config(args.config)
167 conf_args = parser.parse_known_args( 167 args = parser.parse_args(namespace=argparse.Namespace(**args))
168 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
169 168
170 res = defaults.copy() 169 res = defaults.copy()
171 for dict in [vars(conf_args), vars(args)]: 170 for dict in [vars(conf_args), vars(args)]: