diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 6511f9b..d722e68 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -552,6 +552,9 @@ def main(): | |||
552 | prior_loss_weight=args.prior_loss_weight, | 552 | prior_loss_weight=args.prior_loss_weight, |
553 | ) | 553 | ) |
554 | 554 | ||
555 | checkpoint_output_dir = output_dir.joinpath("model") | ||
556 | sample_output_dir = output_dir.joinpath(f"samples") | ||
557 | |||
555 | datamodule = VlpnDataModule( | 558 | datamodule = VlpnDataModule( |
556 | data_file=args.train_data_file, | 559 | data_file=args.train_data_file, |
557 | batch_size=args.train_batch_size, | 560 | batch_size=args.train_batch_size, |
@@ -620,7 +623,8 @@ def main(): | |||
620 | # -- | 623 | # -- |
621 | tokenizer=tokenizer, | 624 | tokenizer=tokenizer, |
622 | sample_scheduler=sample_scheduler, | 625 | sample_scheduler=sample_scheduler, |
623 | output_dir=output_dir, | 626 | sample_output_dir=sample_output_dir, |
627 | checkpoint_output_dir=checkpoint_output_dir, | ||
624 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 628 | train_text_encoder_epochs=args.train_text_encoder_epochs, |
625 | max_grad_norm=args.max_grad_norm, | 629 | max_grad_norm=args.max_grad_norm, |
626 | use_ema=args.use_ema, | 630 | use_ema=args.use_ema, |