diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 17 |
1 file changed, 11 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
| @@ -4,6 +4,7 @@ import logging | |||
| 4 | import itertools | 4 | import itertools |
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | from functools import partial | 6 | from functools import partial |
| 7 | import math | ||
| 7 | 8 | ||
| 8 | import torch | 9 | import torch |
| 9 | import torch.utils.checkpoint | 10 | import torch.utils.checkpoint |
| @@ -189,13 +190,12 @@ def parse_args(): | |||
| 189 | parser.add_argument( | 190 | parser.add_argument( |
| 190 | "--num_train_epochs", | 191 | "--num_train_epochs", |
| 191 | type=int, | 192 | type=int, |
| 192 | default=100 | 193 | default=None |
| 193 | ) | 194 | ) |
| 194 | parser.add_argument( | 195 | parser.add_argument( |
| 195 | "--max_train_steps", | 196 | "--num_train_steps", |
| 196 | type=int, | 197 | type=int, |
| 197 | default=None, | 198 | default=2000 |
| 198 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
| 199 | ) | 199 | ) |
| 200 | parser.add_argument( | 200 | parser.add_argument( |
| 201 | "--gradient_accumulation_steps", | 201 | "--gradient_accumulation_steps", |
| @@ -595,6 +595,11 @@ def main(): | |||
| 595 | ) | 595 | ) |
| 596 | datamodule.setup() | 596 | datamodule.setup() |
| 597 | 597 | ||
| 598 | num_train_epochs = args.num_train_epochs | ||
| 599 | |||
| 600 | if num_train_epochs is None: | ||
| 601 | num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps) | ||
| 602 | |||
| 598 | params_to_optimize = (unet.parameters(), ) | 603 | params_to_optimize = (unet.parameters(), ) |
| 599 | if args.train_text_encoder_epochs != 0: | 604 | if args.train_text_encoder_epochs != 0: |
| 600 | params_to_optimize += ( | 605 | params_to_optimize += ( |
| @@ -619,7 +624,7 @@ def main(): | |||
| 619 | annealing_exp=args.lr_annealing_exp, | 624 | annealing_exp=args.lr_annealing_exp, |
| 620 | cycles=args.lr_cycles, | 625 | cycles=args.lr_cycles, |
| 621 | end_lr=1e2, | 626 | end_lr=1e2, |
| 622 | train_epochs=args.num_train_epochs, | 627 | train_epochs=num_train_epochs, |
| 623 | warmup_epochs=args.lr_warmup_epochs, | 628 | warmup_epochs=args.lr_warmup_epochs, |
| 624 | ) | 629 | ) |
| 625 | 630 | ||
| @@ -631,7 +636,7 @@ def main(): | |||
| 631 | seed=args.seed, | 636 | seed=args.seed, |
| 632 | optimizer=optimizer, | 637 | optimizer=optimizer, |
| 633 | lr_scheduler=lr_scheduler, | 638 | lr_scheduler=lr_scheduler, |
| 634 | num_train_epochs=args.num_train_epochs, | 639 | num_train_epochs=num_train_epochs, |
| 635 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 640 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 636 | sample_frequency=args.sample_frequency, | 641 | sample_frequency=args.sample_frequency, |
| 637 | offset_noise_strength=args.offset_noise_strength, | 642 | offset_noise_strength=args.offset_noise_strength, |
