summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9c1e41c..a70c80e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,6 +16,7 @@ from slugify import slugify
16from util import load_config, load_embeddings_from_dir 16from util import load_config, load_embeddings_from_dir
17from data.csv import VlpnDataModule, keyword_filter 17from data.csv import VlpnDataModule, keyword_filter
18from training.functional import train, get_models 18from training.functional import train, get_models
19from training.lr import plot_metrics
19from training.strategy.dreambooth import dreambooth_strategy 20from training.strategy.dreambooth import dreambooth_strategy
20from training.optimization import get_scheduler 21from training.optimization import get_scheduler
21from training.util import save_args 22from training.util import save_args
@@ -524,6 +525,10 @@ def main():
524 args.train_batch_size * accelerator.num_processes 525 args.train_batch_size * accelerator.num_processes
525 ) 526 )
526 527
528 if args.find_lr:
529 args.learning_rate = 1e-6
530 args.lr_scheduler = "exponential_growth"
531
527 if args.use_8bit_adam: 532 if args.use_8bit_adam:
528 try: 533 try:
529 import bitsandbytes as bnb 534 import bitsandbytes as bnb
@@ -602,11 +607,12 @@ def main():
602 warmup_exp=args.lr_warmup_exp, 607 warmup_exp=args.lr_warmup_exp,
603 annealing_exp=args.lr_annealing_exp, 608 annealing_exp=args.lr_annealing_exp,
604 cycles=args.lr_cycles, 609 cycles=args.lr_cycles,
610 end_lr=1e2,
605 train_epochs=args.num_train_epochs, 611 train_epochs=args.num_train_epochs,
606 warmup_epochs=args.lr_warmup_epochs, 612 warmup_epochs=args.lr_warmup_epochs,
607 ) 613 )
608 614
609 trainer( 615 metrics = trainer(
610 strategy=dreambooth_strategy, 616 strategy=dreambooth_strategy,
611 project="dreambooth", 617 project="dreambooth",
612 train_dataloader=datamodule.train_dataloader, 618 train_dataloader=datamodule.train_dataloader,
@@ -634,6 +640,8 @@ def main():
634 sample_image_size=args.sample_image_size, 640 sample_image_size=args.sample_image_size,
635 ) 641 )
636 642
643 plot_metrics(metrics, output_dir.joinpath("lr.png"))
644
637 645
638if __name__ == "__main__": 646if __name__ == "__main__":
639 main() 647 main()