From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 12:35:43 +0200
Subject: Add support for Adafactor, add TI initializer noise

---
 train_lora.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index f74a438..f8dccae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -317,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -567,6 +568,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
--
cgit v1.2.3-54-g00ecf
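
For context, the Adafactor branch added above uses transformers' relative-step mode, which is why the patch nulls out args.learning_rate and switches the scheduler to "adafactor": in that mode the optimizer derives its own per-step learning rate and must be constructed with lr=None. Below is a minimal standalone sketch of the same setup; the model and the hyperparameter values are placeholders for illustration only, not taken from the repository.

    # Sketch (not part of the patch): Adafactor in relative-step mode, as configured above.
    import torch
    from transformers.optimization import Adafactor, AdafactorSchedule

    model = torch.nn.Linear(4, 4)  # stand-in for the LoRA-wrapped model in train_lora.py

    optimizer = Adafactor(
        model.parameters(),
        lr=None,            # required when relative_step=True (mirrors args.learning_rate = None)
        beta1=None,         # the patch passes args.adam_beta1 here instead
        weight_decay=1e-2,  # stand-in for args.adam_weight_decay
        scale_parameter=True,
        relative_step=True,
        warmup_init=True,   # warmup_init only works together with relative_step=True
    )

    # Proxy schedule that exposes Adafactor's internally computed lr, e.g. for logging.
    lr_scheduler = AdafactorSchedule(optimizer)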