diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 25 |
1 file changed, 23 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index dd015f9..274a1ca 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -12,6 +12,7 @@ from accelerate import Accelerator | |||
| 12 | from accelerate.logging import get_logger | 12 | from accelerate.logging import get_logger |
| 13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
| 14 | from slugify import slugify | 14 | from slugify import slugify |
| 15 | import transformers | ||
| 15 | 16 | ||
| 16 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
| 17 | from data.csv import VlpnDataModule, keyword_filter | 18 | from data.csv import VlpnDataModule, keyword_filter |
| @@ -75,6 +76,12 @@ def parse_args(): | |||
| 75 | help="A token to use as initializer word." | 76 | help="A token to use as initializer word." |
| 76 | ) | 77 | ) |
| 77 | parser.add_argument( | 78 | parser.add_argument( |
| 79 | "--initializer_noise", | ||
| 80 | type=float, | ||
| 81 | default=0, | ||
| 82 | help="Noise to apply to the initializer word" | ||
| 83 | ) | ||
| 84 | parser.add_argument( | ||
| 78 | "--alias_tokens", | 85 | "--alias_tokens", |
| 79 | type=str, | 86 | type=str, |
| 80 | nargs='*', | 87 | nargs='*', |
| @@ -323,7 +330,7 @@ def parse_args(): | |||
| 323 | "--optimizer", | 330 | "--optimizer", |
| 324 | type=str, | 331 | type=str, |
| 325 | default="dadan", | 332 | default="dadan", |
| 326 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 333 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
| 327 | ) | 334 | ) |
| 328 | parser.add_argument( | 335 | parser.add_argument( |
| 329 | "--dadaptation_d0", | 336 | "--dadaptation_d0", |
| @@ -659,6 +666,19 @@ def main(): | |||
| 659 | eps=args.adam_epsilon, | 666 | eps=args.adam_epsilon, |
| 660 | amsgrad=args.adam_amsgrad, | 667 | amsgrad=args.adam_amsgrad, |
| 661 | ) | 668 | ) |
| 669 | elif args.optimizer == 'adafactor': | ||
| 670 | create_optimizer = partial( | ||
| 671 | transformers.optimization.Adafactor, | ||
| 672 | beta1=args.adam_beta1, | ||
| 673 | weight_decay=args.adam_weight_decay, | ||
| 674 | scale_parameter=True, | ||
| 675 | relative_step=True, | ||
| 676 | warmup_init=True, | ||
| 677 | ) | ||
| 678 | |||
| 679 | args.lr_scheduler = "adafactor" | ||
| 680 | args.lr_min_lr = args.learning_rate | ||
| 681 | args.learning_rate = None | ||
| 662 | elif args.optimizer == 'dadam': | 682 | elif args.optimizer == 'dadam': |
| 663 | try: | 683 | try: |
| 664 | import dadaptation | 684 | import dadaptation |
| @@ -739,7 +759,8 @@ def main(): | |||
| 739 | embeddings=embeddings, | 759 | embeddings=embeddings, |
| 740 | placeholder_tokens=placeholder_tokens, | 760 | placeholder_tokens=placeholder_tokens, |
| 741 | initializer_tokens=initializer_tokens, | 761 | initializer_tokens=initializer_tokens, |
| 742 | num_vectors=num_vectors | 762 | num_vectors=num_vectors, |
| 763 | initializer_noise=args.initializer_noise, | ||
| 743 | ) | 764 | ) |
| 744 | 765 | ||
| 745 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) | 766 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) |
