diff options
| -rw-r--r-- | data/dreambooth/csv.py | 24 | ||||
| -rw-r--r-- | dreambooth.py | 22 |
2 files changed, 25 insertions, 21 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 04df4c6..e70c068 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -11,7 +11,7 @@ from torchvision import transforms | |||
| 11 | class CSVDataModule(pl.LightningDataModule): | 11 | class CSVDataModule(pl.LightningDataModule): |
| 12 | def __init__(self, | 12 | def __init__(self, |
| 13 | batch_size, | 13 | batch_size, |
| 14 | data_root, | 14 | data_file, |
| 15 | tokenizer, | 15 | tokenizer, |
| 16 | instance_prompt, | 16 | instance_prompt, |
| 17 | class_data_root=None, | 17 | class_data_root=None, |
| @@ -24,7 +24,12 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 24 | collate_fn=None): | 24 | collate_fn=None): |
| 25 | super().__init__() | 25 | super().__init__() |
| 26 | 26 | ||
| 27 | self.data_root = data_root | 27 | self.data_file = Path(data_file) |
| 28 | |||
| 29 | if not self.data_file.is_file(): | ||
| 30 | raise ValueError("data_file must be a file") | ||
| 31 | |||
| 32 | self.data_root = self.data_file.parent | ||
| 28 | self.tokenizer = tokenizer | 33 | self.tokenizer = tokenizer |
| 29 | self.instance_prompt = instance_prompt | 34 | self.instance_prompt = instance_prompt |
| 30 | self.class_data_root = class_data_root | 35 | self.class_data_root = class_data_root |
| @@ -38,7 +43,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 38 | self.batch_size = batch_size | 43 | self.batch_size = batch_size |
| 39 | 44 | ||
| 40 | def prepare_data(self): | 45 | def prepare_data(self): |
| 41 | metadata = pd.read_csv(f'{self.data_root}/list.csv') | 46 | metadata = pd.read_csv(self.data_file) |
| 42 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 47 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
| 43 | captions = [caption for caption in metadata['caption'].values] | 48 | captions = [caption for caption in metadata['caption'].values] |
| 44 | skips = [skip for skip in metadata['skip'].values] | 49 | skips = [skip for skip in metadata['skip'].values] |
| @@ -50,14 +55,13 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 50 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 55 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) |
| 51 | 56 | ||
| 52 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 57 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
| 53 | class_data_root=self.class_data_root, | 58 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
| 54 | class_prompt=self.class_prompt, size=self.size, repeats=self.repeats, | 59 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
| 55 | interpolation=self.interpolation, identifier=self.identifier, | 60 | center_crop=self.center_crop, repeats=self.repeats) |
| 56 | center_crop=self.center_crop) | ||
| 57 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, | 61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, |
| 58 | class_data_root=self.class_data_root, | 62 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
| 59 | class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, | 63 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, |
| 60 | identifier=self.identifier, center_crop=self.center_crop) | 64 | center_crop=self.center_crop) |
| 61 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 65 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 62 | shuffle=True, collate_fn=self.collate_fn) | 66 | shuffle=True, collate_fn=self.collate_fn) |
| 63 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) | 67 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) |
diff --git a/dreambooth.py b/dreambooth.py index 89ed96a..45a0497 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -46,7 +46,7 @@ def parse_args(): | |||
| 46 | help="Pretrained tokenizer name or path if not the same as model_name", | 46 | help="Pretrained tokenizer name or path if not the same as model_name", |
| 47 | ) | 47 | ) |
| 48 | parser.add_argument( | 48 | parser.add_argument( |
| 49 | "--train_data_dir", | 49 | "--train_data_file", |
| 50 | type=str, | 50 | type=str, |
| 51 | default=None, | 51 | default=None, |
| 52 | help="A folder containing the training data." | 52 | help="A file containing the training data." |
| @@ -269,8 +269,8 @@ def parse_args(): | |||
| 269 | if env_local_rank != -1 and env_local_rank != args.local_rank: | 269 | if env_local_rank != -1 and env_local_rank != args.local_rank: |
| 270 | args.local_rank = env_local_rank | 270 | args.local_rank = env_local_rank |
| 271 | 271 | ||
| 272 | if args.train_data_dir is None: | 272 | if args.train_data_file is None: |
| 273 | raise ValueError("You must specify --train_data_dir") | 273 | raise ValueError("You must specify --train_data_file") |
| 274 | 274 | ||
| 275 | if args.pretrained_model_name_or_path is None: | 275 | if args.pretrained_model_name_or_path is None: |
| 276 | raise ValueError("You must specify --pretrained_model_name_or_path") | 276 | raise ValueError("You must specify --pretrained_model_name_or_path") |
| @@ -587,7 +587,7 @@ def main(): | |||
| 587 | return batch | 587 | return batch |
| 588 | 588 | ||
| 589 | datamodule = CSVDataModule( | 589 | datamodule = CSVDataModule( |
| 590 | data_root=args.train_data_dir, | 590 | data_file=args.train_data_file, |
| 591 | batch_size=args.train_batch_size, | 591 | batch_size=args.train_batch_size, |
| 592 | tokenizer=tokenizer, | 592 | tokenizer=tokenizer, |
| 593 | instance_prompt=args.instance_prompt, | 593 | instance_prompt=args.instance_prompt, |
| @@ -680,12 +680,12 @@ def main(): | |||
| 680 | 0, | 680 | 0, |
| 681 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 681 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 682 | 682 | ||
| 683 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | ||
| 684 | progress_bar.set_description("Global steps") | ||
| 685 | |||
| 686 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | 683 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) |
| 687 | local_progress_bar.set_description("Steps") | 684 | local_progress_bar.set_description("Steps") |
| 688 | 685 | ||
| 686 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | ||
| 687 | progress_bar.set_description("Global steps") | ||
| 688 | |||
| 689 | try: | 689 | try: |
| 690 | for epoch in range(args.num_train_epochs): | 690 | for epoch in range(args.num_train_epochs): |
| 691 | local_progress_bar.reset() | 691 | local_progress_bar.reset() |
| @@ -733,14 +733,14 @@ def main(): | |||
| 733 | 733 | ||
| 734 | # Checks if the accelerator has performed an optimization step behind the scenes | 734 | # Checks if the accelerator has performed an optimization step behind the scenes |
| 735 | if accelerator.sync_gradients: | 735 | if accelerator.sync_gradients: |
| 736 | progress_bar.update(1) | ||
| 737 | local_progress_bar.update(1) | 736 | local_progress_bar.update(1) |
| 737 | progress_bar.update(1) | ||
| 738 | 738 | ||
| 739 | global_step += 1 | 739 | global_step += 1 |
| 740 | 740 | ||
| 741 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | 741 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: |
| 742 | progress_bar.clear() | ||
| 743 | local_progress_bar.clear() | 742 | local_progress_bar.clear() |
| 743 | progress_bar.clear() | ||
| 744 | 744 | ||
| 745 | checkpointer.save_samples( | 745 | checkpointer.save_samples( |
| 746 | "training", | 746 | "training", |
| @@ -782,8 +782,8 @@ def main(): | |||
| 782 | val_loss += loss | 782 | val_loss += loss |
| 783 | 783 | ||
| 784 | if accelerator.sync_gradients: | 784 | if accelerator.sync_gradients: |
| 785 | progress_bar.update(1) | ||
| 786 | local_progress_bar.update(1) | 785 | local_progress_bar.update(1) |
| 786 | progress_bar.update(1) | ||
| 787 | 787 | ||
| 788 | logs = {"mode": "validation", "loss": loss} | 788 | logs = {"mode": "validation", "loss": loss} |
| 789 | local_progress_bar.set_postfix(**logs) | 789 | local_progress_bar.set_postfix(**logs) |
| @@ -792,8 +792,8 @@ def main(): | |||
| 792 | 792 | ||
| 793 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 793 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) |
| 794 | 794 | ||
| 795 | progress_bar.clear() | ||
| 796 | local_progress_bar.clear() | 795 | local_progress_bar.clear() |
| 796 | progress_bar.clear() | ||
| 797 | 797 | ||
| 798 | if min_val_loss > val_loss: | 798 | if min_val_loss > val_loss: |
| 799 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 799 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
