Diffstat (limited to 'data/dreambooth')
-rw-r--r--  data/dreambooth/csv.py | 24 ++++++++++++++----------
1 file changed, 14 insertions, 10 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 04df4c6..e70c068 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -11,7 +11,7 @@ from torchvision import transforms
 class CSVDataModule(pl.LightningDataModule):
     def __init__(self,
                  batch_size,
-                 data_root,
+                 data_file,
                  tokenizer,
                  instance_prompt,
                  class_data_root=None,
@@ -24,7 +24,12 @@ class CSVDataModule(pl.LightningDataModule):
                  collate_fn=None):
         super().__init__()
 
-        self.data_root = data_root
+        self.data_file = Path(data_file)
+
+        if not self.data_file.is_file():
+            raise ValueError("data_file must be a file")
+
+        self.data_root = self.data_file.parent
         self.tokenizer = tokenizer
         self.instance_prompt = instance_prompt
         self.class_data_root = class_data_root
@@ -38,7 +43,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.batch_size = batch_size
 
     def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
         captions = [caption for caption in metadata['caption'].values]
         skips = [skip for skip in metadata['skip'].values]
@@ -50,14 +55,13 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
-                                   class_data_root=self.class_data_root,
-                                   class_prompt=self.class_prompt, size=self.size, repeats=self.repeats,
-                                   interpolation=self.interpolation, identifier=self.identifier,
-                                   center_crop=self.center_crop)
+                                   class_data_root=self.class_data_root, class_prompt=self.class_prompt,
+                                   size=self.size, interpolation=self.interpolation, identifier=self.identifier,
+                                   center_crop=self.center_crop, repeats=self.repeats)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
-                                 class_data_root=self.class_data_root,
-                                 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation,
-                                 identifier=self.identifier, center_crop=self.center_crop)
+                                 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
+                                 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
+                                 center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                             shuffle=True, collate_fn=self.collate_fn)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
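
For callers, the effect of this change is that CSVDataModule now takes the path of the metadata CSV itself (data_file) instead of a directory expected to contain list.csv, derives data_root from that file's parent directory, and fails fast with a ValueError if the path is not a file. A minimal usage sketch of the new interface follows; the tokenizer choice, prompt text, and CSV contents are hypothetical illustrations, and only the data_file/data_root behaviour comes from this diff:

# Minimal sketch of the new data_file interface; tokenizer, prompt,
# and CSV contents below are hypothetical.
import csv
import tempfile
from pathlib import Path

from transformers import CLIPTokenizer

from data.dreambooth.csv import CSVDataModule

# Write a throwaway metadata file with the columns prepare_data() reads.
data_dir = Path(tempfile.mkdtemp())
data_file = data_dir / "list.csv"
with open(data_file, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "caption", "skip"])
    writer.writeheader()
    writer.writerow({"image": "img_0001.jpg", "caption": "a photo of sks person", "skip": False})

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = CSVDataModule(
    batch_size=1,
    data_file=data_file,                      # previously: data_root=data_dir
    tokenizer=tokenizer,
    instance_prompt="a photo of sks person",  # hypothetical prompt
)

# data_root is now derived from the CSV's location, so relative 'image'
# entries resolve against the directory that holds the metadata file.
assert datamodule.data_root == data_dir

A side effect of passing the CSV path directly is that the metadata file no longer has to be named list.csv; any CSV with the expected columns works.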