Diffstat (limited to 'train_ti.py')
 train_ti.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 12e3644..6dc07dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -86,7 +86,7 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
-        "--simultaneous",
+        "--sequential",
         action="store_true",
     )
     parser.add_argument(
@@ -293,7 +293,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adam",
-        help='Optimizer to use ["adam", "adam8bit", "lion"]'
+        help='Optimizer to use ["adam", "adam8bit"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -343,6 +343,11 @@ def parse_args():
         help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
+        "--no_milestone_checkpoints",
+        action='store_true',
+        help="Don't save checkpoints when a new maximum accuracy is reached",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -480,7 +485,7 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if not args.simultaneous:
+    if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -586,13 +591,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -615,6 +613,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
+        milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         # --
         tokenizer=tokenizer,
@@ -715,7 +714,7 @@ def main():
 
     plot_metrics(metrics, metrics_output_file)
 
-    if args.simultaneous:
+    if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
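
Taken together, this patch inverts the run-mode flag (the opt-out "--simultaneous" becomes the opt-in "--sequential", so simultaneous training stays the default), drops the Lion optimizer code path, and threads the new "--no_milestone_checkpoints" switch through to the trainer. A minimal sketch of the resulting dispatch, assuming run() is the training entry point seen in the last hunk and that sequential mode passes single-item lists; the enumerate-based loop body is illustrative, not the exact code:

# Illustrative sketch of the post-patch control flow in main().
def dispatch(args, run):
    if not args.sequential:
        # Default: one run trains all placeholder tokens together.
        run(0, args.placeholder_tokens, args.initializer_tokens,
            args.num_vectors, args.train_data_template)
    else:
        # --sequential: one run per placeholder token, each with its own
        # initializer token, vector count, and data template.
        for i, (token, init, vectors, template) in enumerate(zip(
                args.placeholder_tokens, args.initializer_tokens,
                args.num_vectors, args.train_data_template)):
            run(i + 1, [token], [init], [vectors], [template])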
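
The trainer side of "--no_milestone_checkpoints" is outside this diff; the flag is only wired in as milestone_checkpoints=not args.no_milestone_checkpoints. A hypothetical reading of that parameter, consistent with the flag name and help text but purely an assumption about code not shown here:

# Hypothetical checkpoint gate inside the trainer (not part of this diff);
# save_checkpoint and the accuracy source are assumed names.
best_accuracy = 0.0

def maybe_save_milestone(accuracy, epoch, milestone_checkpoints):
    # Save an extra checkpoint whenever accuracy hits a new maximum,
    # unless milestone checkpoints were disabled on the command line.
    global best_accuracy
    if milestone_checkpoints and accuracy > best_accuracy:
        best_accuracy = accuracy
        save_checkpoint(f"milestone-{epoch}")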
