| author | Volpeon <git@volpeon.ink> | 2022-09-27 18:10:12 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-09-27 18:10:12 +0200 |
| commit | 5210c15fd812328f8f0d7c95d3ed4ec41bdf6444 | |
| tree | a37e52b23393aadc49378230c3eb1f12865d549d | |
| parent | Freeze models that aren't trained | |
Supply dataset CSV file instead of dir with hardcoded CSV filename
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 22 |
|---|---|---|

1 file changed, 11 insertions(+), 11 deletions(-)
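The commit renames the data argument from `--train_data_dir` to `--train_data_file`, so callers now point the script at the CSV file itself rather than at a directory from which a hardcoded CSV filename was loaded. A minimal sketch of the renamed option in isolation (the sample path is hypothetical, not from the repository):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_data_file",  # renamed from --train_data_dir in this commit
    type=str,
    default=None,
    help="A folder containing the training data.",  # help text is left unchanged by the commit
)

# Hypothetical invocation; before this commit you would have passed the parent directory.
args = parser.parse_args(["--train_data_file", "data/train.csv"])
assert args.train_data_file == "data/train.csv"
```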
```diff
diff --git a/dreambooth.py b/dreambooth.py
index 89ed96a..45a0497 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -46,7 +46,7 @@ def parse_args():
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
     parser.add_argument(
-        "--train_data_dir",
+        "--train_data_file",
         type=str,
         default=None,
         help="A folder containing the training data."
```
```diff
@@ -269,8 +269,8 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    if args.train_data_dir is None:
-        raise ValueError("You must specify --train_data_dir")
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
 
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
```
```diff
@@ -587,7 +587,7 @@ def main():
         return batch
 
     datamodule = CSVDataModule(
-        data_root=args.train_data_dir,
+        data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_prompt=args.instance_prompt,
```
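`CSVDataModule` is defined elsewhere in the repository, so this diff only shows its call site switching from `data_root=` to `data_file=`. As a rough sketch of what the module presumably does with the new keyword (the class body below is an assumption for illustration, not code from this commit):

```python
from pathlib import Path

import pandas as pd


class CSVDataModule:
    """Hypothetical sketch of the consuming side of the renamed keyword."""

    def __init__(self, data_file, batch_size, tokenizer, instance_prompt, **kwargs):
        self.data_file = Path(data_file)
        # Assumption: image paths in the CSV resolve against the CSV's parent
        # directory, which is roughly what the old data_root argument named.
        self.data_root = self.data_file.parent
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.instance_prompt = instance_prompt

    def prepare_data(self):
        # Before this commit the module would join a directory with a
        # hardcoded CSV filename; now the caller supplies the file directly.
        self.metadata = pd.read_csv(self.data_file)
```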
```diff
@@ -680,12 +680,12 @@ def main():
         0,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
-
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
     local_progress_bar.set_description("Steps")
 
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Global steps")
+
     try:
         for epoch in range(args.num_train_epochs):
            local_progress_bar.reset()
```
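The remaining hunks swap the creation, update, and clear order of the two tqdm bars so the per-epoch bar consistently comes first. Since tqdm assigns screen positions in creation order, this draws the epoch-local bar above the global one. A self-contained sketch of the same two-bar pattern (the sleep is a stand-in for a training step):

```python
import time

from tqdm import tqdm

num_epochs = 2
steps_per_epoch = 5

# Created first, so tqdm places this bar on the upper line.
local_progress_bar = tqdm(range(steps_per_epoch))
local_progress_bar.set_description("Steps")

# Created second, drawn below the local bar.
progress_bar = tqdm(range(num_epochs * steps_per_epoch))
progress_bar.set_description("Global steps")

for epoch in range(num_epochs):
    local_progress_bar.reset()  # per-epoch bar restarts each epoch
    for step in range(steps_per_epoch):
        time.sleep(0.05)  # stand-in for a training step
        local_progress_bar.update(1)
        progress_bar.update(1)
```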
```diff
@@ -733,14 +733,14 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        progress_bar.clear()
                         local_progress_bar.clear()
+                        progress_bar.clear()
 
                         checkpointer.save_samples(
                             "training",
```
```diff
@@ -782,8 +782,8 @@ def main():
                 val_loss += loss
 
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    progress_bar.update(1)
 
                 logs = {"mode": "validation", "loss": loss}
                 local_progress_bar.set_postfix(**logs)
```
```diff
@@ -792,8 +792,8 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
-            progress_bar.clear()
-            local_progress_bar.clear()
+            local_progress_bar.clear()
+            progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
```
