diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
| commit | 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch) | |
| tree | b1483a52fb853aecb7b73635cded3cce61edf125 /train_dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2 textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 5 |
1 files changed, 1 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 48921d4..f4d4cbb 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -18,7 +18,6 @@ import transformers | |||
| 18 | from util.files import load_config, load_embeddings_from_dir | 18 | from util.files import load_config, load_embeddings_from_dir |
| 19 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
| 20 | from training.functional import train, get_models | 20 | from training.functional import train, get_models |
| 21 | from training.lr import plot_metrics | ||
| 22 | from training.strategy.dreambooth import dreambooth_strategy | 21 | from training.strategy.dreambooth import dreambooth_strategy |
| 23 | from training.optimization import get_scheduler | 22 | from training.optimization import get_scheduler |
| 24 | from training.util import save_args | 23 | from training.util import save_args |
| @@ -692,7 +691,7 @@ def main(): | |||
| 692 | mid_point=args.lr_mid_point, | 691 | mid_point=args.lr_mid_point, |
| 693 | ) | 692 | ) |
| 694 | 693 | ||
| 695 | metrics = trainer( | 694 | trainer( |
| 696 | strategy=dreambooth_strategy, | 695 | strategy=dreambooth_strategy, |
| 697 | project="dreambooth", | 696 | project="dreambooth", |
| 698 | train_dataloader=datamodule.train_dataloader, | 697 | train_dataloader=datamodule.train_dataloader, |
| @@ -721,8 +720,6 @@ def main(): | |||
| 721 | sample_image_size=args.sample_image_size, | 720 | sample_image_size=args.sample_image_size, |
| 722 | ) | 721 | ) |
| 723 | 722 | ||
| 724 | plot_metrics(metrics, output_dir / "lr.png") | ||
| 725 | |||
| 726 | 723 | ||
| 727 | if __name__ == "__main__": | 724 | if __name__ == "__main__": |
| 728 | main() | 725 | main() |
