From 5210c15fd812328f8f0d7c95d3ed4ec41bdf6444 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 18:10:12 +0200 Subject: Supply dataset CSV file instead of dir with hardcoded CSV filename --- data/dreambooth/csv.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) (limited to 'data') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 04df4c6..e70c068 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -11,7 +11,7 @@ from torchvision import transforms class CSVDataModule(pl.LightningDataModule): def __init__(self, batch_size, - data_root, + data_file, tokenizer, instance_prompt, class_data_root=None, @@ -24,7 +24,12 @@ class CSVDataModule(pl.LightningDataModule): collate_fn=None): super().__init__() - self.data_root = data_root + self.data_file = Path(data_file) + + if not self.data_file.is_file(): + raise ValueError("data_file must be a file") + + self.data_root = self.data_file.parent self.tokenizer = tokenizer self.instance_prompt = instance_prompt self.class_data_root = class_data_root @@ -38,7 +43,7 @@ class CSVDataModule(pl.LightningDataModule): self.batch_size = batch_size def prepare_data(self): - metadata = pd.read_csv(f'{self.data_root}/list.csv') + metadata = pd.read_csv(self.data_file) image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] captions = [caption for caption in metadata['caption'].values] skips = [skip for skip in metadata['skip'].values] @@ -50,14 +55,13 @@ class CSVDataModule(pl.LightningDataModule): self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, - class_data_root=self.class_data_root, - class_prompt=self.class_prompt, size=self.size, repeats=self.repeats, - interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop) + class_data_root=self.class_data_root, class_prompt=self.class_prompt, + size=self.size, interpolation=self.interpolation, identifier=self.identifier, + center_crop=self.center_crop, repeats=self.repeats) val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, - class_data_root=self.class_data_root, - class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, - identifier=self.identifier, center_crop=self.center_crop) + class_data_root=self.class_data_root, class_prompt=self.class_prompt, + size=self.size, interpolation=self.interpolation, identifier=self.identifier, + center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) -- cgit v1.2.3-70-g09d2