Diffstat (limited to 'train_lora.py')

 -rw-r--r--  train_lora.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
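For reference, the two options read as follows after this hunk. This is a minimal standalone argparse sketch, not a copy of the script's parse_args(); the bare parser object and the no-argument parse call are illustrative only.

import argparse

parser = argparse.ArgumentParser()

# Epoch count is now optional; when left unset, main() derives it from the step budget.
parser.add_argument(
    "--num_train_epochs",
    type=int,
    default=None,
)

# Renamed from --max_train_steps and given a concrete default instead of None.
parser.add_argument(
    "--num_train_steps",
    type=int,
    default=2000,
)

args = parser.parse_args([])
print(args.num_train_epochs, args.num_train_steps)  # None 2000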
@@ -627,6 +627,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
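The new block only computes an epoch count when --num_train_epochs was not given; an explicit value still wins. Below is a minimal sketch of that fallback using the same expression as the hunk, with dataset_len standing in for len(datamodule.train_dataset):

import math

def resolve_epochs(num_train_epochs, num_train_steps, dataset_len):
    # Explicit epoch count takes precedence; otherwise derive one with the
    # same expression the hunk adds to main().
    if num_train_epochs is None:
        num_train_epochs = math.ceil(dataset_len / num_train_steps)
    return num_train_epochs

print(resolve_epochs(None, 2000, 10000))  # derived: ceil(10000 / 2000) = 5
print(resolve_epochs(100, 2000, 10000))   # explicit value is kept: 100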
@@ -647,7 +652,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -659,7 +664,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
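The last two hunks make both consumers read the resolved value instead of args.num_train_epochs, which may now be None: the LR scheduler receives it as train_epochs and the trainer as num_train_epochs. Below is a schematic sketch of that flow; get_lr_scheduler and run_training are hypothetical stand-ins for the script's scheduler factory and trainer call, and only the keyword names come from the diff.

import math

def get_lr_scheduler(train_epochs, warmup_epochs):
    # Stand-in: a real scheduler would size its warmup/decay from these epoch counts.
    return {"train_epochs": train_epochs, "warmup_epochs": warmup_epochs}

def run_training(num_train_epochs, gradient_accumulation_steps):
    # Stand-in for the trainer invocation at the end of main().
    return f"training for {num_train_epochs} epochs (accumulation={gradient_accumulation_steps})"

# Resolve the epoch count once (as in the previous sketch), then pass the same value everywhere.
num_train_epochs = math.ceil(10000 / 2000)

lr_scheduler = get_lr_scheduler(train_epochs=num_train_epochs, warmup_epochs=5)
print(run_training(num_train_epochs=num_train_epochs, gradient_accumulation_steps=1))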
