diff options
author | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
commit | e0b686b475885f0c8480f7173eaa7359adf17e27 (patch) | |
tree | 6ad882f152e63801d31230466e4d6468e7ada697 /train_dreambooth.py | |
parent | Code simplifications, avoid autocast (diff) | |
download | textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2 textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip |
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c765ec..08bc9e0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
@@ -355,9 +355,8 @@ def parse_args(): | |||
355 | 355 | ||
356 | args = parser.parse_args() | 356 | args = parser.parse_args() |
357 | if args.config is not None: | 357 | if args.config is not None: |
358 | with open(args.config, 'rt') as f: | 358 | args = load_config(args.config) |
359 | args = parser.parse_args( | 359 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
360 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
361 | 360 | ||
362 | if args.train_data_file is None: | 361 | if args.train_data_file is None: |
363 | raise ValueError("You must specify --train_data_file") | 362 | raise ValueError("You must specify --train_data_file") |