From e0b686b475885f0c8480f7173eaa7359adf17e27 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 26 Dec 2022 14:24:21 +0100 Subject: Set default dimensions to 768; add config inheritance --- train_dreambooth.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c765ec..08bc9e0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -20,7 +20,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from common import load_text_embeddings +from common import load_text_embeddings, load_config from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule @@ -355,9 +355,8 @@ def parse_args(): args = parser.parse_args() if args.config is not None: - with open(args.config, 'rt') as f: - args = parser.parse_args( - namespace=argparse.Namespace(**json.load(f)["args"])) + args = load_config(args.config) + args = parser.parse_args(namespace=argparse.Namespace(**args)) if args.train_data_file is None: raise ValueError("You must specify --train_data_file") -- cgit v1.2.3-54-g00ecf