summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-26 14:24:21 +0100
committerVolpeon <git@volpeon.ink>2022-12-26 14:24:21 +0100
commite0b686b475885f0c8480f7173eaa7359adf17e27 (patch)
tree6ad882f152e63801d31230466e4d6468e7ada697 /train_dreambooth.py
parentCode simplifications, avoid autocast (diff)
downloadtextual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz
textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2
textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2c765ec..08bc9e0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_text_embeddings 23from common import load_text_embeddings, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
@@ -355,9 +355,8 @@ def parse_args():
355 355
356 args = parser.parse_args() 356 args = parser.parse_args()
357 if args.config is not None: 357 if args.config is not None:
358 with open(args.config, 'rt') as f: 358 args = load_config(args.config)
359 args = parser.parse_args( 359 args = parser.parse_args(namespace=argparse.Namespace(**args))
360 namespace=argparse.Namespace(**json.load(f)["args"]))
361 360
362 if args.train_data_file is None: 361 if args.train_data_file is None:
363 raise ValueError("You must specify --train_data_file") 362 raise ValueError("You must specify --train_data_file")