From e8d17379efa491a019c3a5fc2633e55ab87e3432 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 31 Mar 2023 12:48:18 +0200
Subject: Support Dadaptation d0, adjust sample freq when steps instead of epochs are used

---
 train_lora.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 9975462..7b54ef8 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -313,6 +313,12 @@ def parse_args():
         default="dadan",
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
+    parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
     parser.add_argument(
         "--adam_beta1",
         type=float,
@@ -567,6 +573,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -580,6 +587,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -628,10 +636,9 @@ def main():
     datamodule.setup()

     num_train_epochs = args.num_train_epochs
-    if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+    num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+    sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
         itertools.chain(
@@ -667,7 +674,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
--
cgit v1.2.3-70-g09d2
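
Note on the last two hunks: when the training length is given in steps (--num_train_steps) rather than epochs, the epoch count is now derived from the dataset size, and the step-based --sample_frequency is converted into an epoch-based frequency so samples are still produced roughly as often as requested. A minimal sketch of that conversion, using made-up example values (the real script reads them from argparse and from the datamodule):

    import math

    # Hypothetical example values; in train_lora.py these come from argparse
    # and from len(datamodule.train_dataset).
    num_train_steps = 2000   # --num_train_steps
    sample_frequency = 200   # --sample_frequency, interpreted in steps
    dataset_len = 150        # len(datamodule.train_dataset)

    # Derive the epoch count from the step budget, then rescale the sample
    # frequency from "every N steps" to "every N epochs".
    num_train_epochs = math.ceil(num_train_steps / dataset_len)                            # -> 14
    sample_frequency = math.ceil(num_train_epochs * (sample_frequency / num_train_steps))  # -> 2

The new --dadaptation_d0 flag is simply forwarded as the d0 keyword (the optimizer's initial D estimate) to dadaptation.DAdaptAdam and dadaptation.DAdaptAdan, assuming a version of the dadaptation package that exposes that parameter.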