diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 16 |
1 file changed, 15 insertions, 1 deletion
diff --git a/train_lora.py b/train_lora.py index f74a438..f8dccae 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -14,6 +14,7 @@ from accelerate.logging import get_logger | |||
| 14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
| 15 | from peft import LoraConfig, LoraModel | 15 | from peft import LoraConfig, LoraModel |
| 16 | from slugify import slugify | 16 | from slugify import slugify |
| 17 | import transformers | ||
| 17 | 18 | ||
| 18 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
| 19 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
| @@ -317,7 +318,7 @@ def parse_args(): | |||
| 317 | "--optimizer", | 318 | "--optimizer", |
| 318 | type=str, | 319 | type=str, |
| 319 | default="dadan", | 320 | default="dadan", |
| 320 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 321 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
| 321 | ) | 322 | ) |
| 322 | parser.add_argument( | 323 | parser.add_argument( |
| 323 | "--dadaptation_d0", | 324 | "--dadaptation_d0", |
| @@ -567,6 +568,19 @@ def main(): | |||
| 567 | eps=args.adam_epsilon, | 568 | eps=args.adam_epsilon, |
| 568 | amsgrad=args.adam_amsgrad, | 569 | amsgrad=args.adam_amsgrad, |
| 569 | ) | 570 | ) |
| 571 | elif args.optimizer == 'adafactor': | ||
| 572 | create_optimizer = partial( | ||
| 573 | transformers.optimization.Adafactor, | ||
| 574 | beta1=args.adam_beta1, | ||
| 575 | weight_decay=args.adam_weight_decay, | ||
| 576 | scale_parameter=True, | ||
| 577 | relative_step=True, | ||
| 578 | warmup_init=True, | ||
| 579 | ) | ||
| 580 | |||
| 581 | args.lr_scheduler = "adafactor" | ||
| 582 | args.lr_min_lr = args.learning_rate | ||
| 583 | args.learning_rate = None | ||
| 570 | elif args.optimizer == 'dadam': | 584 | elif args.optimizer == 'dadam': |
| 571 | try: | 585 | try: |
| 572 | import dadaptation | 586 | import dadaptation |
