summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6511f9b..d722e68 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -552,6 +552,9 @@ def main():
552 prior_loss_weight=args.prior_loss_weight, 552 prior_loss_weight=args.prior_loss_weight,
553 ) 553 )
554 554
555 checkpoint_output_dir = output_dir.joinpath("model")
556 sample_output_dir = output_dir.joinpath(f"samples")
557
555 datamodule = VlpnDataModule( 558 datamodule = VlpnDataModule(
556 data_file=args.train_data_file, 559 data_file=args.train_data_file,
557 batch_size=args.train_batch_size, 560 batch_size=args.train_batch_size,
@@ -620,7 +623,8 @@ def main():
620 # -- 623 # --
621 tokenizer=tokenizer, 624 tokenizer=tokenizer,
622 sample_scheduler=sample_scheduler, 625 sample_scheduler=sample_scheduler,
623 output_dir=output_dir, 626 sample_output_dir=sample_output_dir,
627 checkpoint_output_dir=checkpoint_output_dir,
624 train_text_encoder_epochs=args.train_text_encoder_epochs, 628 train_text_encoder_epochs=args.train_text_encoder_epochs,
625 max_grad_norm=args.max_grad_norm, 629 max_grad_norm=args.max_grad_norm,
626 use_ema=args.use_ema, 630 use_ema=args.use_ema,