summary | refs | log | tree | commit | diff | stats
path: root/data/dreambooth
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth')
-rw-r--r-- data/dreambooth/csv.py 24
1 files changed, 14 insertions, 10 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 04df4c6..e70c068 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -11,7 +11,7 @@ from torchvision import transforms
11class CSVDataModule(pl.LightningDataModule): 11class CSVDataModule(pl.LightningDataModule):
12 def __init__(self, 12 def __init__(self,
13 batch_size, 13 batch_size,
14 data_root, 14 data_file,
15 tokenizer, 15 tokenizer,
16 instance_prompt, 16 instance_prompt,
17 class_data_root=None, 17 class_data_root=None,
@@ -24,7 +24,12 @@ class CSVDataModule(pl.LightningDataModule):
24 collate_fn=None): 24 collate_fn=None):
25 super().__init__() 25 super().__init__()
26 26
27 self.data_root = data_root 27 self.data_file = Path(data_file)
28
29 if not self.data_file.is_file():
30 raise ValueError("data_file must be a file")
31
32 self.data_root = self.data_file.parent
28 self.tokenizer = tokenizer 33 self.tokenizer = tokenizer
29 self.instance_prompt = instance_prompt 34 self.instance_prompt = instance_prompt
30 self.class_data_root = class_data_root 35 self.class_data_root = class_data_root
@@ -38,7 +43,7 @@ class CSVDataModule(pl.LightningDataModule):
38 self.batch_size = batch_size 43 self.batch_size = batch_size
39 44
40 def prepare_data(self): 45 def prepare_data(self):
41 metadata = pd.read_csv(f'{self.data_root}/list.csv') 46 metadata = pd.read_csv(self.data_file)
42 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] 47 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
43 captions = [caption for caption in metadata['caption'].values] 48 captions = [caption for caption in metadata['caption'].values]
44 skips = [skip for skip in metadata['skip'].values] 49 skips = [skip for skip in metadata['skip'].values]
@@ -50,14 +55,13 @@ class CSVDataModule(pl.LightningDataModule):
50 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 55 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
51 56
52 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 57 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
53 class_data_root=self.class_data_root, 58 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
54 class_prompt=self.class_prompt, size=self.size, repeats=self.repeats, 59 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
55 interpolation=self.interpolation, identifier=self.identifier, 60 center_crop=self.center_crop, repeats=self.repeats)
56 center_crop=self.center_crop)
57 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, 61 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
58 class_data_root=self.class_data_root, 62 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
59 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, 63 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
60 identifier=self.identifier, center_crop=self.center_crop) 64 center_crop=self.center_crop)
61 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 65 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
62 shuffle=True, collate_fn=self.collate_fn) 66 shuffle=True, collate_fn=self.collate_fn)
63 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) 67 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)