From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 12:35:43 +0200
Subject: Add support for Adafactor, add TI initializer noise

---
 train_dreambooth.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3a25efa..4456bd1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,6 +13,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -305,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -535,6 +536,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
-- 
cgit v1.2.3-70-g09d2
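
Note (not part of the patch): a minimal sketch of the Adafactor setup the branch above enables. With scale_parameter, relative_step and warmup_init all set to True, transformers' Adafactor derives its own per-step learning rate, which is why the patch sets args.learning_rate to None and switches the scheduler to "adafactor". The model, beta1, weight_decay and initial_lr values below are placeholders for illustration, not values taken from this repository.

    # Hedged sketch, not part of the commit; `model` and the hyperparameters
    # below are placeholders.
    import torch
    import transformers

    model = torch.nn.Linear(8, 8)  # stand-in for the network being fine-tuned

    # With relative_step=True (and warmup_init=True), Adafactor computes its
    # own step size, so no explicit lr is passed -- this mirrors the patch
    # setting args.learning_rate = None.
    optimizer = transformers.optimization.Adafactor(
        model.parameters(),
        lr=None,
        beta1=0.9,          # plays the role of args.adam_beta1 in the patch
        weight_decay=1e-2,  # plays the role of args.adam_weight_decay
        scale_parameter=True,
        relative_step=True,
        warmup_init=True,
    )

    # AdafactorSchedule exposes the internally computed step size so the
    # training loop can still log a learning rate; initial_lr corresponds
    # loosely to the args.lr_min_lr fallback stored by the patch.
    lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, initial_lr=1e-6)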