From 21d70916f66e74a87c631a06b70774954b085b48 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 14:14:00 +0200
Subject: Fix

---
 train_lora.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index daf1f6c..476efcf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -548,15 +548,18 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if args.initializer_tokens is None:
+        args.initializer_tokens = []
+
+    if args.placeholder_tokens is None:
+        args.placeholder_tokens = []
+
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
@@ -884,7 +887,7 @@ def main():
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
         ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
 
         pti_optimizer = create_optimizer(
             [
@@ -915,7 +918,7 @@ def main():
             # --
             sample_output_dir=pti_sample_output_dir,
             checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=pti_sample_frequency,
+            sample_frequency=math.inf,
             placeholder_tokens=args.placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
             use_emb_decay=args.use_emb_decay,
-- 
cgit v1.2.3-70-g09d2
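
For reference, the parse_args() hunk above reduces to the normalization sketched below: neither --placeholder_tokens nor --initializer_tokens is required any longer, a single string is expanded to a list, and placeholder names are synthesized when only initializer tokens are given. This is a minimal standalone sketch, not the full function; the argparse.Namespace values are made-up example inputs, while the real values come from the argument parser in train_lora.py.

import argparse


def normalize_token_args(args: argparse.Namespace) -> argparse.Namespace:
    # Missing token lists now default to empty instead of being required.
    if args.initializer_tokens is None:
        args.initializer_tokens = []

    if args.placeholder_tokens is None:
        args.placeholder_tokens = []

    # A single placeholder string becomes a one-element list.
    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    # A single initializer string is repeated once per placeholder token.
    if isinstance(args.initializer_tokens, str):
        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)

    # With no placeholders given, synthesize <*0>, <*1>, ... to match the initializers.
    if len(args.placeholder_tokens) == 0:
        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

    return args


# Made-up example inputs:
args = normalize_token_args(
    argparse.Namespace(placeholder_tokens=None, initializer_tokens=["cat", "dog"])
)
print(args.placeholder_tokens)   # ['<*0>', '<*1>']
print(args.initializer_tokens)   # ['cat', 'dog']

# Omitting both no longer raises "You must specify --initializer_tokens":
args = normalize_token_args(
    argparse.Namespace(placeholder_tokens=None, initializer_tokens=None)
)
print(args.placeholder_tokens)   # []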