 data/dreambooth/csv.py | 24 ++++++++++++++----------
 dreambooth.py          | 22 +++++++++++-----------
 2 files changed, 25 insertions(+), 21 deletions(-)
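Summary: CSVDataModule now takes the metadata CSV itself (data_file) instead of a directory (data_root), validates that the path is a file, and derives the image root from the file's parent directory; prepare_data() reads that file directly rather than assuming a {data_root}/list.csv layout. dreambooth.py renames --train_data_dir to --train_data_file to match, and the two tqdm progress bars are reordered so the per-epoch "Steps" bar is created, updated, and cleared before the run-wide "Global steps" bar.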
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 04df4c6..e70c068 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -11,7 +11,7 @@ from torchvision import transforms
 class CSVDataModule(pl.LightningDataModule):
     def __init__(self,
                  batch_size,
-                 data_root,
+                 data_file,
                  tokenizer,
                  instance_prompt,
                  class_data_root=None,
@@ -24,7 +24,12 @@ class CSVDataModule(pl.LightningDataModule):
                  collate_fn=None):
         super().__init__()
 
-        self.data_root = data_root
+        self.data_file = Path(data_file)
+
+        if not self.data_file.is_file():
+            raise ValueError("data_file must be a file")
+
+        self.data_root = self.data_file.parent
         self.tokenizer = tokenizer
         self.instance_prompt = instance_prompt
         self.class_data_root = class_data_root
@@ -38,7 +43,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.batch_size = batch_size
 
     def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
         captions = [caption for caption in metadata['caption'].values]
         skips = [skip for skip in metadata['skip'].values]
@@ -50,14 +55,13 @@
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
-                                   class_data_root=self.class_data_root,
-                                   class_prompt=self.class_prompt, size=self.size, repeats=self.repeats,
-                                   interpolation=self.interpolation, identifier=self.identifier,
-                                   center_crop=self.center_crop)
+                                   class_data_root=self.class_data_root, class_prompt=self.class_prompt,
+                                   size=self.size, interpolation=self.interpolation, identifier=self.identifier,
+                                   center_crop=self.center_crop, repeats=self.repeats)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
-                                 class_data_root=self.class_data_root,
-                                 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation,
-                                 identifier=self.identifier, center_crop=self.center_crop)
+                                 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
+                                 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
+                                 center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                             shuffle=True, collate_fn=self.collate_fn)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
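For orientation, a minimal sketch of driving the updated module. The import path follows the file's location, the tokenizer and CSV path are illustrative, and it assumes csv.py imports Path from pathlib (the import sits outside these hunks' context):

    # Hypothetical usage sketch, not part of the commit.
    from data.dreambooth.csv import CSVDataModule

    datamodule = CSVDataModule(
        batch_size=1,
        data_file="data/dreambooth/img/list.csv",  # CSV with 'image', 'caption', 'skip' columns
        tokenizer=tokenizer,                       # illustrative: whatever tokenizer the datasets expect
        instance_prompt="a photo of sks dog",      # illustrative prompt
    )
    datamodule.prepare_data()                      # reads the CSV and builds both dataloaders

Image paths in the CSV are resolved relative to the CSV's own directory, since data_root is now derived from data_file.parent.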
diff --git a/dreambooth.py b/dreambooth.py
index 89ed96a..45a0497 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -46,7 +46,7 @@ def parse_args():
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
     parser.add_argument(
-        "--train_data_dir",
+        "--train_data_file",
         type=str,
         default=None,
         help="A folder containing the training data."
@@ -269,8 +269,8 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    if args.train_data_dir is None:
-        raise ValueError("You must specify --train_data_dir")
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
 
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
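With the rename, an invocation points at the metadata file rather than its directory, along the lines of (paths and prompt illustrative, remaining flags elided):

    python dreambooth.py \
        --pretrained_model_name_or_path <model name or path> \
        --train_data_file data/dreambooth/img/list.csv \
        --instance_prompt "a photo of sks dog"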
@@ -587,7 +587,7 @@ def main():
         return batch
 
     datamodule = CSVDataModule(
-        data_root=args.train_data_dir,
+        data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_prompt=args.instance_prompt,
@@ -680,12 +680,12 @@ def main():
             0,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
-
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
     local_progress_bar.set_description("Steps")
 
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Global steps")
+
     try:
         for epoch in range(args.num_train_epochs):
             local_progress_bar.reset()
@@ -733,14 +733,14 @@
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        progress_bar.clear()
                         local_progress_bar.clear()
+                        progress_bar.clear()
 
                         checkpointer.save_samples(
                             "training",
@@ -782,8 +782,8 @@ def main():
                     val_loss += loss
 
                     if accelerator.sync_gradients:
-                        progress_bar.update(1)
                         local_progress_bar.update(1)
+                        progress_bar.update(1)
 
                     logs = {"mode": "validation", "loss": loss}
                     local_progress_bar.set_postfix(**logs)
@@ -792,8 +792,8 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
-            progress_bar.clear()
             local_progress_bar.clear()
+            progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
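The progress-bar hunks only swap creation and update order so the per-epoch bar consistently comes first. A self-contained sketch of the resulting pattern (loop sizes illustrative, no accelerator involved):

    # Hypothetical standalone sketch of the two-bar layout after this commit.
    from tqdm import tqdm

    num_update_steps_per_epoch = 10             # illustrative
    num_train_epochs, max_train_steps = 3, 30   # illustrative

    local_progress_bar = tqdm(range(num_update_steps_per_epoch))
    local_progress_bar.set_description("Steps")

    progress_bar = tqdm(range(max_train_steps))
    progress_bar.set_description("Global steps")

    for epoch in range(num_train_epochs):
        local_progress_bar.reset()              # per-epoch bar restarts each epoch
        for _ in range(num_update_steps_per_epoch):
            local_progress_bar.update(1)        # epoch-local bar first, as in the new order
            progress_bar.update(1)              # then the run-wide bar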