Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  15
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 9975462..7b54ef8 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -314,6 +314,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
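
Note: the hunk above only adds the option itself. Below is a minimal standalone sketch of how the flag parses; everything except --dadaptation_d0, its float type, and its 1e-6 default is an illustrative assumption (the real parse_args() defines many more options).

# Standalone sketch of the new flag; only --dadaptation_d0 / float / 1e-6 come from this diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    default="dadam",
    help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
)
parser.add_argument(
    "--dadaptation_d0",
    type=float,
    default=1e-6,
    help="The d0 parameter for Dadaptation optimizers."
)

args = parser.parse_args(["--optimizer", "dadam", "--dadaptation_d0", "1e-5"])
print(args.optimizer, args.dadaptation_d0)  # dadam 1e-05
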
@@ -567,6 +573,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -580,6 +587,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
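
Note: both optimizer hunks forward the new value as d0=args.dadaptation_d0. The sketch below shows the same wiring, assuming the dadaptation package (whose DAdaptAdam accepts weight_decay, eps, decouple, and d0) and a hypothetical make_dadapt_adam helper; the repository's actual create_optimizer is not part of this diff.

# Sketch, not the repository's create_optimizer: bind d0 exactly as the hunks above do.
from functools import partial

import dadaptation

def make_dadapt_adam(args):
    # Hypothetical helper; returns a constructor with the D-Adaptation knobs pre-bound.
    return partial(
        dadaptation.DAdaptAdam,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=True,
        d0=args.dadaptation_d0,
    )

# Usage: make_dadapt_adam(args)(model.parameters(), lr=1.0)
# D-Adaptation starts from the initial step-size estimate d0 and adapts it during
# training, which is why the surrounding code also pins args.learning_rate to 1.0.
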
@@ -628,10 +636,9 @@ def main():
     datamodule.setup()

     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
         itertools.chain(
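
Note: the rewritten epoch logic also converts --sample_frequency from a step count into an epoch count, since the trainer (next hunk) now receives the local sample_frequency. A worked example of that arithmetic, with illustrative numbers that are not from the repository:

# Illustrative numbers only; the formulas are the ones introduced in the hunk above.
import math

num_train_steps = 2000   # --num_train_steps
dataset_len = 120        # len(datamodule.train_dataset)
sample_steps = 200       # --sample_frequency, previously interpreted as "every N steps"

num_train_epochs = math.ceil(num_train_steps / dataset_len)                        # 17
sample_frequency = math.ceil(num_train_epochs * (sample_steps / num_train_steps))  # 2

print(num_train_epochs, sample_frequency)  # sample roughly every 2 epochs instead of every 200 steps
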
@@ -667,7 +674,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,