summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-09-27 18:10:12 +0200
committerVolpeon <git@volpeon.ink>2022-09-27 18:10:12 +0200
commit5210c15fd812328f8f0d7c95d3ed4ec41bdf6444 (patch)
treea37e52b23393aadc49378230c3eb1f12865d549d /dreambooth.py
parentFreeze models that aren't trained (diff)
downloadtextual-inversion-diff-5210c15fd812328f8f0d7c95d3ed4ec41bdf6444.tar.gz
textual-inversion-diff-5210c15fd812328f8f0d7c95d3ed4ec41bdf6444.tar.bz2
textual-inversion-diff-5210c15fd812328f8f0d7c95d3ed4ec41bdf6444.zip
Supply dataset CSV file instead of dir with hardcoded CSV filename
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 89ed96a..45a0497 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -46,7 +46,7 @@ def parse_args():
46 help="Pretrained tokenizer name or path if not the same as model_name", 46 help="Pretrained tokenizer name or path if not the same as model_name",
47 ) 47 )
48 parser.add_argument( 48 parser.add_argument(
49 "--train_data_dir", 49 "--train_data_file",
50 type=str, 50 type=str,
51 default=None, 51 default=None,
52 help="A folder containing the training data." 52 help="A folder containing the training data."
@@ -269,8 +269,8 @@ def parse_args():
269 if env_local_rank != -1 and env_local_rank != args.local_rank: 269 if env_local_rank != -1 and env_local_rank != args.local_rank:
270 args.local_rank = env_local_rank 270 args.local_rank = env_local_rank
271 271
272 if args.train_data_dir is None: 272 if args.train_data_file is None:
273 raise ValueError("You must specify --train_data_dir") 273 raise ValueError("You must specify --train_data_file")
274 274
275 if args.pretrained_model_name_or_path is None: 275 if args.pretrained_model_name_or_path is None:
276 raise ValueError("You must specify --pretrained_model_name_or_path") 276 raise ValueError("You must specify --pretrained_model_name_or_path")
@@ -587,7 +587,7 @@ def main():
587 return batch 587 return batch
588 588
589 datamodule = CSVDataModule( 589 datamodule = CSVDataModule(
590 data_root=args.train_data_dir, 590 data_file=args.train_data_file,
591 batch_size=args.train_batch_size, 591 batch_size=args.train_batch_size,
592 tokenizer=tokenizer, 592 tokenizer=tokenizer,
593 instance_prompt=args.instance_prompt, 593 instance_prompt=args.instance_prompt,
@@ -680,12 +680,12 @@ def main():
680 0, 680 0,
681 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 681 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
682 682
683 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
684 progress_bar.set_description("Global steps")
685
686 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) 683 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
687 local_progress_bar.set_description("Steps") 684 local_progress_bar.set_description("Steps")
688 685
686 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
687 progress_bar.set_description("Global steps")
688
689 try: 689 try:
690 for epoch in range(args.num_train_epochs): 690 for epoch in range(args.num_train_epochs):
691 local_progress_bar.reset() 691 local_progress_bar.reset()
@@ -733,14 +733,14 @@ def main():
733 733
734 # Checks if the accelerator has performed an optimization step behind the scenes 734 # Checks if the accelerator has performed an optimization step behind the scenes
735 if accelerator.sync_gradients: 735 if accelerator.sync_gradients:
736 progress_bar.update(1)
737 local_progress_bar.update(1) 736 local_progress_bar.update(1)
737 progress_bar.update(1)
738 738
739 global_step += 1 739 global_step += 1
740 740
741 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: 741 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
742 progress_bar.clear()
743 local_progress_bar.clear() 742 local_progress_bar.clear()
743 progress_bar.clear()
744 744
745 checkpointer.save_samples( 745 checkpointer.save_samples(
746 "training", 746 "training",
@@ -782,8 +782,8 @@ def main():
782 val_loss += loss 782 val_loss += loss
783 783
784 if accelerator.sync_gradients: 784 if accelerator.sync_gradients:
785 progress_bar.update(1)
786 local_progress_bar.update(1) 785 local_progress_bar.update(1)
786 progress_bar.update(1)
787 787
788 logs = {"mode": "validation", "loss": loss} 788 logs = {"mode": "validation", "loss": loss}
789 local_progress_bar.set_postfix(**logs) 789 local_progress_bar.set_postfix(**logs)
@@ -792,8 +792,8 @@ def main():
792 792
793 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) 793 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
794 794
795 progress_bar.clear()
796 local_progress_bar.clear() 795 local_progress_bar.clear()
796 progress_bar.clear()
797 797
798 if min_val_loss > val_loss: 798 if min_val_loss > val_loss:
799 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 799 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")