diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 22 |
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 89ed96a..45a0497 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -46,7 +46,7 @@ def parse_args():
46 | help="Pretrained tokenizer name or path if not the same as model_name", | 46 | help="Pretrained tokenizer name or path if not the same as model_name", |
47 | ) | 47 | ) |
48 | parser.add_argument( | 48 | parser.add_argument( |
49 | "--train_data_dir", | 49 | "--train_data_file", |
50 | type=str, | 50 | type=str, |
51 | default=None, | 51 | default=None, |
52 | help="A folder containing the training data." | 52 | help="A folder containing the training data." |
@@ -269,8 +269,8 @@ def parse_args():
269 | if env_local_rank != -1 and env_local_rank != args.local_rank: | 269 | if env_local_rank != -1 and env_local_rank != args.local_rank: |
270 | args.local_rank = env_local_rank | 270 | args.local_rank = env_local_rank |
271 | 271 | ||
272 | if args.train_data_dir is None: | 272 | if args.train_data_file is None: |
273 | raise ValueError("You must specify --train_data_dir") | 273 | raise ValueError("You must specify --train_data_file") |
274 | 274 | ||
275 | if args.pretrained_model_name_or_path is None: | 275 | if args.pretrained_model_name_or_path is None: |
276 | raise ValueError("You must specify --pretrained_model_name_or_path") | 276 | raise ValueError("You must specify --pretrained_model_name_or_path") |
@@ -587,7 +587,7 @@ def main():
587 | return batch | 587 | return batch |
588 | 588 | ||
589 | datamodule = CSVDataModule( | 589 | datamodule = CSVDataModule( |
590 | data_root=args.train_data_dir, | 590 | data_file=args.train_data_file, |
591 | batch_size=args.train_batch_size, | 591 | batch_size=args.train_batch_size, |
592 | tokenizer=tokenizer, | 592 | tokenizer=tokenizer, |
593 | instance_prompt=args.instance_prompt, | 593 | instance_prompt=args.instance_prompt, |
@@ -680,12 +680,12 @@ def main():
680 | 0, | 680 | 0, |
681 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 681 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
682 | 682 | ||
683 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | ||
684 | progress_bar.set_description("Global steps") | ||
685 | |||
686 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | 683 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) |
687 | local_progress_bar.set_description("Steps") | 684 | local_progress_bar.set_description("Steps") |
688 | 685 | ||
686 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | ||
687 | progress_bar.set_description("Global steps") | ||
688 | |||
689 | try: | 689 | try: |
690 | for epoch in range(args.num_train_epochs): | 690 | for epoch in range(args.num_train_epochs): |
691 | local_progress_bar.reset() | 691 | local_progress_bar.reset() |
@@ -733,14 +733,14 @@ def main():
733 | 733 | ||
734 | # Checks if the accelerator has performed an optimization step behind the scenes | 734 | # Checks if the accelerator has performed an optimization step behind the scenes |
735 | if accelerator.sync_gradients: | 735 | if accelerator.sync_gradients: |
736 | progress_bar.update(1) | ||
737 | local_progress_bar.update(1) | 736 | local_progress_bar.update(1) |
737 | progress_bar.update(1) | ||
738 | 738 | ||
739 | global_step += 1 | 739 | global_step += 1 |
740 | 740 | ||
741 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | 741 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: |
742 | progress_bar.clear() | ||
743 | local_progress_bar.clear() | 742 | local_progress_bar.clear() |
743 | progress_bar.clear() | ||
744 | 744 | ||
745 | checkpointer.save_samples( | 745 | checkpointer.save_samples( |
746 | "training", | 746 | "training", |
@@ -782,8 +782,8 @@ def main():
782 | val_loss += loss | 782 | val_loss += loss |
783 | 783 | ||
784 | if accelerator.sync_gradients: | 784 | if accelerator.sync_gradients: |
785 | progress_bar.update(1) | ||
786 | local_progress_bar.update(1) | 785 | local_progress_bar.update(1) |
786 | progress_bar.update(1) | ||
787 | 787 | ||
788 | logs = {"mode": "validation", "loss": loss} | 788 | logs = {"mode": "validation", "loss": loss} |
789 | local_progress_bar.set_postfix(**logs) | 789 | local_progress_bar.set_postfix(**logs) |
@@ -792,8 +792,8 @@ def main():
792 | 792 | ||
793 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 793 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) |
794 | 794 | ||
795 | progress_bar.clear() | ||
796 | local_progress_bar.clear() | 795 | local_progress_bar.clear() |
796 | progress_bar.clear() | ||
797 | 797 | ||
798 | if min_val_loss > val_loss: | 798 | if min_val_loss > val_loss: |
799 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 799 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |