diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 9c1e41c..a70c80e 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -16,6 +16,7 @@ from slugify import slugify | |||
16 | from util import load_config, load_embeddings_from_dir | 16 | from util import load_config, load_embeddings_from_dir |
17 | from data.csv import VlpnDataModule, keyword_filter | 17 | from data.csv import VlpnDataModule, keyword_filter |
18 | from training.functional import train, get_models | 18 | from training.functional import train, get_models |
19 | from training.lr import plot_metrics | ||
19 | from training.strategy.dreambooth import dreambooth_strategy | 20 | from training.strategy.dreambooth import dreambooth_strategy |
20 | from training.optimization import get_scheduler | 21 | from training.optimization import get_scheduler |
21 | from training.util import save_args | 22 | from training.util import save_args |
@@ -524,6 +525,10 @@ def main(): | |||
524 | args.train_batch_size * accelerator.num_processes | 525 | args.train_batch_size * accelerator.num_processes |
525 | ) | 526 | ) |
526 | 527 | ||
528 | if args.find_lr: | ||
529 | args.learning_rate = 1e-6 | ||
530 | args.lr_scheduler = "exponential_growth" | ||
531 | |||
527 | if args.use_8bit_adam: | 532 | if args.use_8bit_adam: |
528 | try: | 533 | try: |
529 | import bitsandbytes as bnb | 534 | import bitsandbytes as bnb |
@@ -602,11 +607,12 @@ def main(): | |||
602 | warmup_exp=args.lr_warmup_exp, | 607 | warmup_exp=args.lr_warmup_exp, |
603 | annealing_exp=args.lr_annealing_exp, | 608 | annealing_exp=args.lr_annealing_exp, |
604 | cycles=args.lr_cycles, | 609 | cycles=args.lr_cycles, |
610 | end_lr=1e2, | ||
605 | train_epochs=args.num_train_epochs, | 611 | train_epochs=args.num_train_epochs, |
606 | warmup_epochs=args.lr_warmup_epochs, | 612 | warmup_epochs=args.lr_warmup_epochs, |
607 | ) | 613 | ) |
608 | 614 | ||
609 | trainer( | 615 | metrics = trainer( |
610 | strategy=dreambooth_strategy, | 616 | strategy=dreambooth_strategy, |
611 | project="dreambooth", | 617 | project="dreambooth", |
612 | train_dataloader=datamodule.train_dataloader, | 618 | train_dataloader=datamodule.train_dataloader, |
@@ -634,6 +640,8 @@ def main(): | |||
634 | sample_image_size=args.sample_image_size, | 640 | sample_image_size=args.sample_image_size, |
635 | ) | 641 | ) |
636 | 642 | ||
643 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | ||
644 | |||
637 | 645 | ||
638 | if __name__ == "__main__": | 646 | if __name__ == "__main__": |
639 | main() | 647 | main() |