summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 21:47:06 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 21:47:06 +0100
commit2dfd1790078753f19ca8c585ac77079f3114f3a9 (patch)
treed1d1d643f247767c13535105dbe4afafcc5ab8c0 /train_dreambooth.py
parentIf valid set size is 0, re-use one image from train set (diff)
downloadtextual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.gz
textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.bz2
textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.zip
Training update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6511f9b..d722e68 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -552,6 +552,9 @@ def main():
552 prior_loss_weight=args.prior_loss_weight, 552 prior_loss_weight=args.prior_loss_weight,
553 ) 553 )
554 554
555 checkpoint_output_dir = output_dir.joinpath("model")
556 sample_output_dir = output_dir.joinpath(f"samples")
557
555 datamodule = VlpnDataModule( 558 datamodule = VlpnDataModule(
556 data_file=args.train_data_file, 559 data_file=args.train_data_file,
557 batch_size=args.train_batch_size, 560 batch_size=args.train_batch_size,
@@ -620,7 +623,8 @@ def main():
620 # -- 623 # --
621 tokenizer=tokenizer, 624 tokenizer=tokenizer,
622 sample_scheduler=sample_scheduler, 625 sample_scheduler=sample_scheduler,
623 output_dir=output_dir, 626 sample_output_dir=sample_output_dir,
627 checkpoint_output_dir=checkpoint_output_dir,
624 train_text_encoder_epochs=args.train_text_encoder_epochs, 628 train_text_encoder_epochs=args.train_text_encoder_epochs,
625 max_grad_norm=args.max_grad_norm, 629 max_grad_norm=args.max_grad_norm,
626 use_ema=args.use_ema, 630 use_ema=args.use_ema,