diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
| commit | e0b686b475885f0c8480f7173eaa7359adf17e27 (patch) | |
| tree | 6ad882f152e63801d31230466e4d6468e7ada697 /train_dreambooth.py | |
| parent | Code simplifications, avoid autocast (diff) | |
| download | textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2 textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip | |
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c765ec..08bc9e0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
| 20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
| 21 | from slugify import slugify | 21 | from slugify import slugify |
| 22 | 22 | ||
| 23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
| 26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
| @@ -355,9 +355,8 @@ def parse_args(): | |||
| 355 | 355 | ||
| 356 | args = parser.parse_args() | 356 | args = parser.parse_args() |
| 357 | if args.config is not None: | 357 | if args.config is not None: |
| 358 | with open(args.config, 'rt') as f: | 358 | args = load_config(args.config) |
| 359 | args = parser.parse_args( | 359 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
| 360 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
| 361 | 360 | ||
| 362 | if args.train_data_file is None: | 361 | if args.train_data_file is None: |
| 363 | raise ValueError("You must specify --train_data_file") | 362 | raise ValueError("You must specify --train_data_file") |
