From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Apr 2023 12:35:43 +0200 Subject: Add support for Adafactor, add TI initializer noise --- train_ti.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index dd015f9..274a1ca 100644 --- a/train_ti.py +++ b/train_ti.py @@ -12,6 +12,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from slugify import slugify +import transformers from util.files import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter @@ -74,6 +75,12 @@ def parse_args(): nargs='*', help="A token to use as initializer word." ) + parser.add_argument( + "--initializer_noise", + type=float, + default=0, + help="Noise to apply to the initializer word" + ) parser.add_argument( "--alias_tokens", type=str, @@ -323,7 +330,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' + help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' ) parser.add_argument( "--dadaptation_d0", @@ -659,6 +666,19 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adafactor': + create_optimizer = partial( + transformers.optimization.Adafactor, + beta1=args.adam_beta1, + weight_decay=args.adam_weight_decay, + scale_parameter=True, + relative_step=True, + warmup_init=True, + ) + + args.lr_scheduler = "adafactor" + args.lr_min_lr = args.learning_rate + args.learning_rate = None elif args.optimizer == 'dadam': try: import dadaptation @@ -739,7 +759,8 @@ def main(): embeddings=embeddings, placeholder_tokens=placeholder_tokens, initializer_tokens=initializer_tokens, - num_vectors=num_vectors + num_vectors=num_vectors, + initializer_noise=args.initializer_noise, ) stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) -- cgit v1.2.3-54-g00ecf