From 2dfd1790078753f19ca8c585ac77079f3114f3a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 21:47:06 +0100 Subject: Training update --- train_dreambooth.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 6511f9b..d722e68 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -552,6 +552,9 @@ def main(): prior_loss_weight=args.prior_loss_weight, ) + checkpoint_output_dir = output_dir.joinpath("model") + sample_output_dir = output_dir.joinpath(f"samples") + datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, @@ -620,7 +623,8 @@ def main(): # -- tokenizer=tokenizer, sample_scheduler=sample_scheduler, - output_dir=output_dir, + sample_output_dir=sample_output_dir, + checkpoint_output_dir=checkpoint_output_dir, train_text_encoder_epochs=args.train_text_encoder_epochs, max_grad_norm=args.max_grad_norm, use_ema=args.use_ema, -- cgit v1.2.3-54-g00ecf