Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 25
1 file changed, 23 insertions(+), 2 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index dd015f9..274a1ca 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -75,6 +76,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
+    parser.add_argument(
         "--alias_tokens",
         type=str,
         nargs='*',
@@ -323,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -659,6 +666,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -739,7 +759,8 @@ def main():
         embeddings=embeddings,
         placeholder_tokens=placeholder_tokens,
         initializer_tokens=initializer_tokens,
-        num_vectors=num_vectors
+        num_vectors=num_vectors,
+        initializer_noise=args.initializer_noise,
     )
 
     stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
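
The diff only shows --initializer_noise being parsed and forwarded; as a minimal sketch of what a nonzero value plausibly does, assuming (per the help text) it is the scale of random noise added to the copied initializer embedding, with the Gaussian choice, tensor shape, and variable names being illustrative assumptions rather than code from this commit:

import torch

embedding_dim = 768                                   # illustrative embedding size
initializer_embedding = torch.randn(embedding_dim)    # stand-in for the initializer token's embedding
initializer_noise = 0.1                               # value passed via --initializer_noise

# Perturb the copied embedding so each new placeholder vector starts near,
# but not exactly at, the initializer word's embedding.
noisy_embedding = initializer_embedding + torch.randn_like(initializer_embedding) * initializer_noise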
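
For the new adafactor branch, a minimal sketch of how transformers' Adafactor behaves with the settings used above: with relative_step=True it derives its own step size, which is why the diff clears args.learning_rate. The parameter list is a hypothetical stand-in, and mapping the "adafactor" scheduler name to AdafactorSchedule is an assumption, not something this diff shows.

import torch
import transformers

# Hypothetical stand-in for the trainable embedding parameters in train_ti.py.
params = [torch.nn.Parameter(torch.zeros(4, 768))]

# Mirrors the optimizer settings added in the diff; no external learning rate
# is passed because relative_step=True lets Adafactor compute it internally.
optimizer = transformers.optimization.Adafactor(
    params,
    lr=None,
    beta1=None,
    weight_decay=0.0,
    scale_parameter=True,
    relative_step=True,
    warmup_init=True,
)

# Assumption: the "adafactor" scheduler selected in the diff could be backed by
# AdafactorSchedule, which simply reports the LR Adafactor computes internally.
lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer)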