diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/train_lora.py b/train_lora.py index 9975462..7b54ef8 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -314,6 +314,12 @@ def parse_args(): | |||
314 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 314 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
315 | ) | 315 | ) |
316 | parser.add_argument( | 316 | parser.add_argument( |
317 | "--dadaptation_d0", | ||
318 | type=float, | ||
319 | default=1e-6, | ||
320 | help="The d0 parameter for Dadaptation optimizers." | ||
321 | ) | ||
322 | parser.add_argument( | ||
317 | "--adam_beta1", | 323 | "--adam_beta1", |
318 | type=float, | 324 | type=float, |
319 | default=0.9, | 325 | default=0.9, |
@@ -567,6 +573,7 @@ def main(): | |||
567 | weight_decay=args.adam_weight_decay, | 573 | weight_decay=args.adam_weight_decay, |
568 | eps=args.adam_epsilon, | 574 | eps=args.adam_epsilon, |
569 | decouple=True, | 575 | decouple=True, |
576 | d0=args.dadaptation_d0, | ||
570 | ) | 577 | ) |
571 | 578 | ||
572 | args.learning_rate = 1.0 | 579 | args.learning_rate = 1.0 |
@@ -580,6 +587,7 @@ def main(): | |||
580 | dadaptation.DAdaptAdan, | 587 | dadaptation.DAdaptAdan, |
581 | weight_decay=args.adam_weight_decay, | 588 | weight_decay=args.adam_weight_decay, |
582 | eps=args.adam_epsilon, | 589 | eps=args.adam_epsilon, |
590 | d0=args.dadaptation_d0, | ||
583 | ) | 591 | ) |
584 | 592 | ||
585 | args.learning_rate = 1.0 | 593 | args.learning_rate = 1.0 |
@@ -628,10 +636,9 @@ def main(): | |||
628 | datamodule.setup() | 636 | datamodule.setup() |
629 | 637 | ||
630 | num_train_epochs = args.num_train_epochs | 638 | num_train_epochs = args.num_train_epochs |
631 | |||
632 | if num_train_epochs is None: | 639 | if num_train_epochs is None: |
633 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size | 640 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
634 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | 641 | sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) |
635 | 642 | ||
636 | optimizer = create_optimizer( | 643 | optimizer = create_optimizer( |
637 | itertools.chain( | 644 | itertools.chain( |
@@ -667,7 +674,7 @@ def main(): | |||
667 | lr_scheduler=lr_scheduler, | 674 | lr_scheduler=lr_scheduler, |
668 | num_train_epochs=num_train_epochs, | 675 | num_train_epochs=num_train_epochs, |
669 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 676 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
670 | sample_frequency=args.sample_frequency, | 677 | sample_frequency=sample_frequency, |
671 | offset_noise_strength=args.offset_noise_strength, | 678 | offset_noise_strength=args.offset_noise_strength, |
672 | # -- | 679 | # -- |
673 | tokenizer=tokenizer, | 680 | tokenizer=tokenizer, |