diff options
author | Volpeon <git@volpeon.ink> | 2023-03-28 16:24:22 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-28 16:24:22 +0200 |
commit | 179a45253a5b3712f32bd127f693a6bb810a9c17 (patch) | |
tree | ac9f1152d858089742e4f9ce79e0870e0f2b9a2d /train_lora.py | |
parent | Fix TI (diff) | |
download | textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.tar.gz textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.tar.bz2 textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.zip |
Support num_train_steps arg again
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/train_lora.py b/train_lora.py index 7ecddf0..a9c6e52 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -4,6 +4,7 @@ import logging | |||
4 | import itertools | 4 | import itertools |
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from functools import partial | 6 | from functools import partial |
7 | import math | ||
7 | 8 | ||
8 | import torch | 9 | import torch |
9 | import torch.utils.checkpoint | 10 | import torch.utils.checkpoint |
@@ -178,13 +179,12 @@ def parse_args(): | |||
178 | parser.add_argument( | 179 | parser.add_argument( |
179 | "--num_train_epochs", | 180 | "--num_train_epochs", |
180 | type=int, | 181 | type=int, |
181 | default=100 | 182 | default=None |
182 | ) | 183 | ) |
183 | parser.add_argument( | 184 | parser.add_argument( |
184 | "--max_train_steps", | 185 | "--num_train_steps", |
185 | type=int, | 186 | type=int, |
186 | default=None, | 187 | default=2000 |
187 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
188 | ) | 188 | ) |
189 | parser.add_argument( | 189 | parser.add_argument( |
190 | "--gradient_accumulation_steps", | 190 | "--gradient_accumulation_steps", |
@@ -627,6 +627,11 @@ def main(): | |||
627 | ) | 627 | ) |
628 | datamodule.setup() | 628 | datamodule.setup() |
629 | 629 | ||
630 | num_train_epochs = args.num_train_epochs | ||
631 | |||
632 | if num_train_epochs is None: | ||
633 | num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps) | ||
634 | |||
630 | optimizer = create_optimizer( | 635 | optimizer = create_optimizer( |
631 | itertools.chain( | 636 | itertools.chain( |
632 | unet.parameters(), | 637 | unet.parameters(), |
@@ -647,7 +652,7 @@ def main(): | |||
647 | annealing_exp=args.lr_annealing_exp, | 652 | annealing_exp=args.lr_annealing_exp, |
648 | cycles=args.lr_cycles, | 653 | cycles=args.lr_cycles, |
649 | end_lr=1e2, | 654 | end_lr=1e2, |
650 | train_epochs=args.num_train_epochs, | 655 | train_epochs=num_train_epochs, |
651 | warmup_epochs=args.lr_warmup_epochs, | 656 | warmup_epochs=args.lr_warmup_epochs, |
652 | ) | 657 | ) |
653 | 658 | ||
@@ -659,7 +664,7 @@ def main(): | |||
659 | seed=args.seed, | 664 | seed=args.seed, |
660 | optimizer=optimizer, | 665 | optimizer=optimizer, |
661 | lr_scheduler=lr_scheduler, | 666 | lr_scheduler=lr_scheduler, |
662 | num_train_epochs=args.num_train_epochs, | 667 | num_train_epochs=num_train_epochs, |
663 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 668 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
664 | sample_frequency=args.sample_frequency, | 669 | sample_frequency=args.sample_frequency, |
665 | offset_noise_strength=args.offset_noise_strength, | 670 | offset_noise_strength=args.offset_noise_strength, |