summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-08 07:58:14 +0200
committerVolpeon <git@volpeon.ink>2023-04-08 07:58:14 +0200
commit5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch)
treeb1483a52fb853aecb7b73635cded3cce61edf125 /train_dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz
textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2
textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48921d4..f4d4cbb 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -18,7 +18,6 @@ import transformers
18from util.files import load_config, load_embeddings_from_dir 18from util.files import load_config, load_embeddings_from_dir
19from data.csv import VlpnDataModule, keyword_filter 19from data.csv import VlpnDataModule, keyword_filter
20from training.functional import train, get_models 20from training.functional import train, get_models
21from training.lr import plot_metrics
22from training.strategy.dreambooth import dreambooth_strategy 21from training.strategy.dreambooth import dreambooth_strategy
23from training.optimization import get_scheduler 22from training.optimization import get_scheduler
24from training.util import save_args 23from training.util import save_args
@@ -692,7 +691,7 @@ def main():
692 mid_point=args.lr_mid_point, 691 mid_point=args.lr_mid_point,
693 ) 692 )
694 693
695 metrics = trainer( 694 trainer(
696 strategy=dreambooth_strategy, 695 strategy=dreambooth_strategy,
697 project="dreambooth", 696 project="dreambooth",
698 train_dataloader=datamodule.train_dataloader, 697 train_dataloader=datamodule.train_dataloader,
@@ -721,8 +720,6 @@ def main():
721 sample_image_size=args.sample_image_size, 720 sample_image_size=args.sample_image_size,
722 ) 721 )
723 722
724 plot_metrics(metrics, output_dir / "lr.png")
725
726 723
727if __name__ == "__main__": 724if __name__ == "__main__":
728 main() 725 main()