Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  16
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/train_lora.py b/train_lora.py
index f74a438..f8dccae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -317,7 +318,7 @@ def parse_args():
317 "--optimizer", 318 "--optimizer",
318 type=str, 319 type=str,
319 default="dadan", 320 default="dadan",
320 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' 321 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
321 ) 322 )
322 parser.add_argument( 323 parser.add_argument(
323 "--dadaptation_d0", 324 "--dadaptation_d0",
@@ -567,6 +568,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
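
The new branch hands control of the learning rate to Adafactor's own relative-step schedule, which is why the hunk also forces args.lr_scheduler to "adafactor" and clears args.learning_rate. Below is a minimal, self-contained sketch of that pairing using the transformers Adafactor implementation; the toy model, the concrete hyperparameter values, and the use of AdafactorSchedule to read back the internally computed step size are illustrative assumptions, not code taken from train_lora.py.

# Sketch only: mirrors the optimizer settings added above; the model and
# hyperparameter values are stand-ins, and AdafactorSchedule is an assumption
# about what an "adafactor" lr_scheduler choice resolves to elsewhere.
import torch
import transformers

model = torch.nn.Linear(16, 16)  # stand-in for the LoRA parameters

optimizer = transformers.optimization.Adafactor(
    model.parameters(),
    lr=None,                # must stay None when relative_step=True
    beta1=0.9,              # plays the role of args.adam_beta1
    weight_decay=1e-2,      # plays the role of args.adam_weight_decay
    scale_parameter=True,
    relative_step=True,     # Adafactor derives its own step size per update
    warmup_init=True,       # start from a tiny step and ramp up
)

# The schedule does not set a learning rate; it only reports the one the
# optimizer computes internally, so an external learning rate goes unused.
lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer)

for _ in range(3):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    print(lr_scheduler.get_last_lr())

Selecting the branch from the command line only needs the choice added in parse_args(), i.e. --optimizer adafactor; the remaining arguments are whatever the script already requires.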