diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 13 |
1 file changed, 8 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py index daf1f6c..476efcf 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -548,15 +548,18 @@ def parse_args(): | |||
548 | if args.project is None: | 548 | if args.project is None: |
549 | raise ValueError("You must specify --project") | 549 | raise ValueError("You must specify --project") |
550 | 550 | ||
551 | if args.initializer_tokens is None: | ||
552 | args.initializer_tokens = [] | ||
553 | |||
554 | if args.placeholder_tokens is None: | ||
555 | args.placeholder_tokens = [] | ||
556 | |||
551 | if isinstance(args.placeholder_tokens, str): | 557 | if isinstance(args.placeholder_tokens, str): |
552 | args.placeholder_tokens = [args.placeholder_tokens] | 558 | args.placeholder_tokens = [args.placeholder_tokens] |
553 | 559 | ||
554 | if isinstance(args.initializer_tokens, str): | 560 | if isinstance(args.initializer_tokens, str): |
555 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 561 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) |
556 | 562 | ||
557 | if len(args.initializer_tokens) == 0: | ||
558 | raise ValueError("You must specify --initializer_tokens") | ||
559 | |||
560 | if len(args.placeholder_tokens) == 0: | 563 | if len(args.placeholder_tokens) == 0: |
561 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 564 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
562 | 565 | ||
@@ -884,7 +887,7 @@ def main(): | |||
884 | num_pti_epochs = math.ceil( | 887 | num_pti_epochs = math.ceil( |
885 | args.num_pti_steps / len(pti_datamodule.train_dataset) | 888 | args.num_pti_steps / len(pti_datamodule.train_dataset) |
886 | ) * args.pti_gradient_accumulation_steps | 889 | ) * args.pti_gradient_accumulation_steps |
887 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps)) | 890 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) |
888 | 891 | ||
889 | pti_optimizer = create_optimizer( | 892 | pti_optimizer = create_optimizer( |
890 | [ | 893 | [ |
@@ -915,7 +918,7 @@ def main(): | |||
915 | # -- | 918 | # -- |
916 | sample_output_dir=pti_sample_output_dir, | 919 | sample_output_dir=pti_sample_output_dir, |
917 | checkpoint_output_dir=pti_checkpoint_output_dir, | 920 | checkpoint_output_dir=pti_checkpoint_output_dir, |
918 | sample_frequency=pti_sample_frequency, | 921 | sample_frequency=math.inf, |
919 | placeholder_tokens=args.placeholder_tokens, | 922 | placeholder_tokens=args.placeholder_tokens, |
920 | placeholder_token_ids=placeholder_token_ids, | 923 | placeholder_token_ids=placeholder_token_ids, |
921 | use_emb_decay=args.use_emb_decay, | 924 | use_emb_decay=args.use_emb_decay, |