diff options
author | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
commit | e0b686b475885f0c8480f7173eaa7359adf17e27 (patch) | |
tree | 6ad882f152e63801d31230466e4d6468e7ada697 /train_lora.py | |
parent | Code simplifications, avoid autocast (diff) | |
download | textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2 textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip |
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/train_lora.py b/train_lora.py index 34e1008..ffca304 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
26 | from training.lora import LoraAttnProcessor | 26 | from training.lora import LoraAttnProcessor |
@@ -317,9 +317,8 @@ def parse_args(): | |||
317 | 317 | ||
318 | args = parser.parse_args() | 318 | args = parser.parse_args() |
319 | if args.config is not None: | 319 | if args.config is not None: |
320 | with open(args.config, 'rt') as f: | 320 | args = load_config(args.config) |
321 | args = parser.parse_args( | 321 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
322 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
323 | 322 | ||
324 | if args.train_data_file is None: | 323 | if args.train_data_file is None: |
325 | raise ValueError("You must specify --train_data_file") | 324 | raise ValueError("You must specify --train_data_file") |