Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 21 ++++++++++-----------
1 file changed, 10 insertions(+), 11 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 12e3644..6dc07dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -86,7 +86,7 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
-        "--simultaneous",
+        "--sequential",
         action="store_true",
     )
     parser.add_argument(
@@ -293,7 +293,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -343,6 +343,11 @@ def parse_args():
         help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
+        "--no_milestone_checkpoints",
+        action='store_true',
+        help="If checkpoints are saved on maximum accuracy",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -480,7 +485,7 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if not args.simultaneous:
+    if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -586,13 +591,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -615,6 +613,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
+        milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         # --
         tokenizer=tokenizer,
@@ -715,7 +714,7 @@ def main():
 
     plot_metrics(metrics, metrics_output_file)
 
-    if args.simultaneous:
+    if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
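
The flag rename is also a polarity flip: the old opt-in `--simultaneous` switch becomes an opt-in `--sequential` switch, so training all embeddings together in a single run is now the default. Below is a minimal sketch of the resulting control flow, assuming a simplified `run` stub and hypothetical token/template values (the real call also passes initializer tokens and vector counts):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--sequential", action="store_true")
args = parser.parse_args(["--sequential"])  # simulate passing the flag

placeholder_tokens = ["<token-a>", "<token-b>"]  # hypothetical tokens
train_data_template = "style"                    # hypothetical template

def run(i, tokens, template):
    # Stand-in for the script's training routine.
    print(f"run {i}: tokens={tokens} template={template}")

if not args.sequential:
    # New default: one joint run over all placeholder tokens.
    run(0, placeholder_tokens, train_data_template)
else:
    # --sequential: broadcast a single string template to every token,
    # as the parse_args() hunk does, then train each embedding alone.
    templates = (
        [train_data_template] * len(placeholder_tokens)
        if isinstance(train_data_template, str)
        else train_data_template
    )
    for i, (token, template) in enumerate(zip(placeholder_tokens, templates)):
        run(i, [token], template)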
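
The new `--no_milestone_checkpoints` option follows the usual negative-flag pattern: the CLI exposes the disable switch, and the call site inverts it once so the trainer receives a positive `milestone_checkpoints` boolean and the behavior stays on by default. A hedged sketch follows; the `train` signature here is hypothetical, since the diff only shows the keyword argument being passed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_milestone_checkpoints",
    action="store_true",
    help="Disable checkpoints that are saved when accuracy reaches a new maximum",
)
args = parser.parse_args([])  # flag not passed

def train(milestone_checkpoints: bool):
    # Hypothetical stand-in for the trainer call in main().
    print(f"milestone checkpoints enabled: {milestone_checkpoints}")

# Invert exactly once at the call site, as the diff does, so downstream
# code only ever reasons about the positive condition.
train(milestone_checkpoints=not args.no_milestone_checkpoints)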
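
Removing the Lion branch leaves a two-way dispatch on --optimizer. A sketch of the surviving shape under stated assumptions: torch.optim.AdamW for "adam" and bitsandbytes' AdamW8bit for "adam8bit" (the diff does not show which classes train_ti.py actually uses, and args.learning_rate is an assumed attribute); the deleted branch guarded its optional lion_pytorch import the same way the adam8bit branch below guards bitsandbytes:

from functools import partial

import torch

def get_optimizer_factory(args):
    # Returns a callable that builds an optimizer for a parameter list,
    # mirroring the create_optimizer = partial(...) pattern in the diff.
    if args.optimizer == "adam":
        return partial(
            torch.optim.AdamW,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
    elif args.optimizer == "adam8bit":
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use adam8bit, please install bitsandbytes: `pip install bitsandbytes`."
            )
        return partial(
            bnb.optim.AdamW8bit,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
        )
    else:
        raise ValueError(f'Unknown --optimizer "{args.optimizer}"')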