summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/train_lora.py b/train_lora.py
index 34e1008..ffca304 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_text_embeddings 23from common import load_text_embeddings, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.lora import LoraAttnProcessor 26from training.lora import LoraAttnProcessor
@@ -317,9 +317,8 @@ def parse_args():
317 317
318 args = parser.parse_args() 318 args = parser.parse_args()
319 if args.config is not None: 319 if args.config is not None:
320 with open(args.config, 'rt') as f: 320 args = load_config(args.config)
321 args = parser.parse_args( 321 args = parser.parse_args(namespace=argparse.Namespace(**args))
322 namespace=argparse.Namespace(**json.load(f)["args"]))
323 322
324 if args.train_data_file is None: 323 if args.train_data_file is None:
325 raise ValueError("You must specify --train_data_file") 324 raise ValueError("You must specify --train_data_file")