Diffstat (limited to 'train_lora.py')
-rw-r--r--	train_lora.py	17
1 file changed, 11 insertions(+), 6 deletions(-)
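This commit retires --max_train_steps (previously default None, overriding --num_train_epochs when set) in favor of --num_train_steps with a default of 2000, and changes --num_train_epochs to default to None. When no epoch count is given, main() now derives one from the step budget and the size of the training set (hence the new math import), and that derived value, rather than args.num_train_epochs, is what gets passed to the learning-rate scheduler and the trainer.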
diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -627,6 +627,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
@@ -647,7 +652,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -659,7 +664,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
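
For reference, the default-resolution logic introduced in main() can be read in isolation as the sketch below. It is an illustrative reconstruction, not code from the repo: the function name resolve_num_train_epochs and the literal dataset size are invented for the example, and the derivation reproduces the committed expression, ceil(len(datamodule.train_dataset) / args.num_train_steps), verbatim.

import math

def resolve_num_train_epochs(num_train_epochs, num_train_steps, dataset_len):
    # Mirrors the committed precedence: an explicit --num_train_epochs wins;
    # otherwise an epoch count is derived from the step budget and dataset size.
    if num_train_epochs is None:
        # Derivation exactly as committed: ceil(dataset length / step budget).
        num_train_epochs = math.ceil(dataset_len / num_train_steps)
    return num_train_epochs

# With the new defaults (num_train_epochs=None, num_train_steps=2000),
# a hypothetical 5000-item training set resolves to ceil(5000 / 2000) = 3 epochs:
print(resolve_num_train_epochs(None, 2000, 5000))  # 3
# An explicit epoch count bypasses the derivation entirely:
print(resolve_num_train_epochs(100, 2000, 5000))   # 100

Design-wise, this makes the 2000-step budget the default contract while keeping --num_train_epochs as an explicit override; both the learning-rate schedule and the trainer consume the same resolved epoch count, so the two stay consistent.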