 data/dreambooth/csv.py | 24 ++++++++++++++----------
 dreambooth.py          | 22 +++++++++++-----------
 2 files changed, 25 insertions(+), 21 deletions(-)
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 04df4c6..e70c068 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -11,7 +11,7 @@ from torchvision import transforms
 class CSVDataModule(pl.LightningDataModule):
     def __init__(self,
                  batch_size,
-                 data_root,
+                 data_file,
                  tokenizer,
                  instance_prompt,
                  class_data_root=None,
@@ -24,7 +24,12 @@ class CSVDataModule(pl.LightningDataModule):
                  collate_fn=None):
         super().__init__()
 
-        self.data_root = data_root
+        self.data_file = Path(data_file)
+
+        if not self.data_file.is_file():
+            raise ValueError("data_file must be a file")
+
+        self.data_root = self.data_file.parent
         self.tokenizer = tokenizer
         self.instance_prompt = instance_prompt
         self.class_data_root = class_data_root
@@ -38,7 +43,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.batch_size = batch_size
 
     def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
         captions = [caption for caption in metadata['caption'].values]
         skips = [skip for skip in metadata['skip'].values]
@@ -50,14 +55,13 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
-                                   class_data_root=self.class_data_root,
-                                   class_prompt=self.class_prompt, size=self.size, repeats=self.repeats,
-                                   interpolation=self.interpolation, identifier=self.identifier,
-                                   center_crop=self.center_crop)
+                                   class_data_root=self.class_data_root, class_prompt=self.class_prompt,
+                                   size=self.size, interpolation=self.interpolation, identifier=self.identifier,
+                                   center_crop=self.center_crop, repeats=self.repeats)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
-                                 class_data_root=self.class_data_root,
-                                 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation,
-                                 identifier=self.identifier, center_crop=self.center_crop)
+                                 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
+                                 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
+                                 center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                             shuffle=True, collate_fn=self.collate_fn)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
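
For reference, the reworked CSVDataModule now takes the CSV file itself instead of a directory with a hard-coded list.csv: __init__ validates that data_file points at an existing file and derives data_root from its parent directory, and prepare_data() reads the image, caption, and skip columns, resolving each image path against that directory. A minimal usage sketch under those assumptions; the import path, tokenizer choice, file name, and prompt below are illustrative, not taken from the commit:

    from transformers import CLIPTokenizer

    from data.dreambooth.csv import CSVDataModule  # assumed import path for this repo layout

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # hypothetical tokenizer choice

    datamodule = CSVDataModule(
        batch_size=1,
        data_file="data/train/metadata.csv",  # __init__ raises ValueError unless this is an existing file
        tokenizer=tokenizer,
        instance_prompt="a photo of sks dog",  # illustrative prompt
    )
    datamodule.prepare_data()  # reads the CSV; 'image' entries resolve against data/train/
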
diff --git a/dreambooth.py b/dreambooth.py
index 89ed96a..45a0497 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -46,7 +46,7 @@ def parse_args():
46 help="Pretrained tokenizer name or path if not the same as model_name", 46 help="Pretrained tokenizer name or path if not the same as model_name",
47 ) 47 )
48 parser.add_argument( 48 parser.add_argument(
49 "--train_data_dir", 49 "--train_data_file",
50 type=str, 50 type=str,
51 default=None, 51 default=None,
52 help="A folder containing the training data." 52 help="A folder containing the training data."
@@ -269,8 +269,8 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    if args.train_data_dir is None:
-        raise ValueError("You must specify --train_data_dir")
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
 
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
@@ -587,7 +587,7 @@ def main():
         return batch
 
     datamodule = CSVDataModule(
-        data_root=args.train_data_dir,
+        data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_prompt=args.instance_prompt,
@@ -680,12 +680,12 @@ def main():
                 0,
                 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
-
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
     local_progress_bar.set_description("Steps")
 
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Global steps")
+
     try:
         for epoch in range(args.num_train_epochs):
             local_progress_bar.reset()
@@ -733,14 +733,14 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        progress_bar.clear()
                         local_progress_bar.clear()
+                        progress_bar.clear()
 
                         checkpointer.save_samples(
                             "training",
@@ -782,8 +782,8 @@ def main():
                     val_loss += loss
 
                     if accelerator.sync_gradients:
-                        progress_bar.update(1)
                         local_progress_bar.update(1)
+                        progress_bar.update(1)
 
                     logs = {"mode": "validation", "loss": loss}
                     local_progress_bar.set_postfix(**logs)
@@ -792,8 +792,8 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
-            progress_bar.clear()
             local_progress_bar.clear()
+            progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
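
The remaining dreambooth.py hunks are pure reorderings: the per-epoch bar is now created, updated, and cleared before the global bar instead of after it, presumably so the two bars keep a stable on-screen order. A standalone sketch of the same two-bar tqdm pattern; the step counts are illustrative, not from the commit:

    from tqdm.auto import tqdm

    num_update_steps_per_epoch = 100  # illustrative
    max_train_steps = 300             # illustrative

    local_progress_bar = tqdm(range(num_update_steps_per_epoch))
    local_progress_bar.set_description("Steps")

    progress_bar = tqdm(range(max_train_steps))
    progress_bar.set_description("Global steps")

    for epoch in range(max_train_steps // num_update_steps_per_epoch):
        local_progress_bar.reset()  # restart the per-epoch bar without destroying it
        for _ in range(num_update_steps_per_epoch):
            local_progress_bar.update(1)
            progress_bar.update(1)
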