diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 6511f9b..d722e68 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -552,6 +552,9 @@ def main(): | |||
| 552 | prior_loss_weight=args.prior_loss_weight, | 552 | prior_loss_weight=args.prior_loss_weight, |
| 553 | ) | 553 | ) |
| 554 | 554 | ||
| 555 | checkpoint_output_dir = output_dir.joinpath("model") | ||
| 556 | sample_output_dir = output_dir.joinpath(f"samples") | ||
| 557 | |||
| 555 | datamodule = VlpnDataModule( | 558 | datamodule = VlpnDataModule( |
| 556 | data_file=args.train_data_file, | 559 | data_file=args.train_data_file, |
| 557 | batch_size=args.train_batch_size, | 560 | batch_size=args.train_batch_size, |
| @@ -620,7 +623,8 @@ def main(): | |||
| 620 | # -- | 623 | # -- |
| 621 | tokenizer=tokenizer, | 624 | tokenizer=tokenizer, |
| 622 | sample_scheduler=sample_scheduler, | 625 | sample_scheduler=sample_scheduler, |
| 623 | output_dir=output_dir, | 626 | sample_output_dir=sample_output_dir, |
| 627 | checkpoint_output_dir=checkpoint_output_dir, | ||
| 624 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 628 | train_text_encoder_epochs=args.train_text_encoder_epochs, |
| 625 | max_grad_norm=args.max_grad_norm, | 629 | max_grad_norm=args.max_grad_norm, |
| 626 | use_ema=args.use_ema, | 630 | use_ema=args.use_ema, |
