diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 5 |
1 files changed, 1 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 48921d4..f4d4cbb 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -18,7 +18,6 @@ import transformers | |||
18 | from util.files import load_config, load_embeddings_from_dir | 18 | from util.files import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
20 | from training.functional import train, get_models | 20 | from training.functional import train, get_models |
21 | from training.lr import plot_metrics | ||
22 | from training.strategy.dreambooth import dreambooth_strategy | 21 | from training.strategy.dreambooth import dreambooth_strategy |
23 | from training.optimization import get_scheduler | 22 | from training.optimization import get_scheduler |
24 | from training.util import save_args | 23 | from training.util import save_args |
@@ -692,7 +691,7 @@ def main(): | |||
692 | mid_point=args.lr_mid_point, | 691 | mid_point=args.lr_mid_point, |
693 | ) | 692 | ) |
694 | 693 | ||
695 | metrics = trainer( | 694 | trainer( |
696 | strategy=dreambooth_strategy, | 695 | strategy=dreambooth_strategy, |
697 | project="dreambooth", | 696 | project="dreambooth", |
698 | train_dataloader=datamodule.train_dataloader, | 697 | train_dataloader=datamodule.train_dataloader, |
@@ -721,8 +720,6 @@ def main(): | |||
721 | sample_image_size=args.sample_image_size, | 720 | sample_image_size=args.sample_image_size, |
722 | ) | 721 | ) |
723 | 722 | ||
724 | plot_metrics(metrics, output_dir / "lr.png") | ||
725 | |||
726 | 723 | ||
727 | if __name__ == "__main__": | 724 | if __name__ == "__main__": |
728 | main() | 725 | main() |