From 5210c15fd812328f8f0d7c95d3ed4ec41bdf6444 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 18:10:12 +0200
Subject: Supply dataset CSV file instead of dir with hardcoded CSV filename

---
 dreambooth.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 89ed96a..45a0497 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -46,7 +46,7 @@ def parse_args():
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
     parser.add_argument(
-        "--train_data_dir",
+        "--train_data_file",
         type=str,
         default=None,
         help="A folder containing the training data."
@@ -269,8 +269,8 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    if args.train_data_dir is None:
-        raise ValueError("You must specify --train_data_dir")
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
 
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
@@ -587,7 +587,7 @@ def main():
         return batch
 
     datamodule = CSVDataModule(
-        data_root=args.train_data_dir,
+        data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_prompt=args.instance_prompt,
@@ -680,12 +680,12 @@ def main():
             0, args.resolution, args.resolution,
             7.5, 0.0, args.sample_steps)
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
-
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
     local_progress_bar.set_description("Steps")
 
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Global steps")
+
     try:
         for epoch in range(args.num_train_epochs):
             local_progress_bar.reset()
@@ -733,14 +733,14 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        progress_bar.clear()
                         local_progress_bar.clear()
+                        progress_bar.clear()
 
                         checkpointer.save_samples(
                             "training",
@@ -782,8 +782,8 @@ def main():
                     val_loss += loss
 
                     if accelerator.sync_gradients:
-                        progress_bar.update(1)
                         local_progress_bar.update(1)
+                        progress_bar.update(1)
 
                     logs = {"mode": "validation", "loss": loss}
                     local_progress_bar.set_postfix(**logs)
@@ -792,8 +792,8 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
-            progress_bar.clear()
             local_progress_bar.clear()
+            progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-- 
cgit v1.2.3-54-g00ecf