diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index f9b5e39..6bd7f9b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 38 | center_crop: bool = False, | 38 | center_crop: bool = False, |
| 39 | valid_set_size: Optional[int] = None, | 39 | valid_set_size: Optional[int] = None, |
| 40 | generator: Optional[torch.Generator] = None, | 40 | generator: Optional[torch.Generator] = None, |
| 41 | collate_fn=None | 41 | collate_fn=None, |
| 42 | num_workers: int = 0 | ||
| 42 | ): | 43 | ): |
| 43 | super().__init__() | 44 | super().__init__() |
| 44 | 45 | ||
| @@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 62 | self.valid_set_size = valid_set_size | 63 | self.valid_set_size = valid_set_size |
| 63 | self.generator = generator | 64 | self.generator = generator |
| 64 | self.collate_fn = collate_fn | 65 | self.collate_fn = collate_fn |
| 66 | self.num_workers = num_workers | ||
| 65 | self.batch_size = batch_size | 67 | self.batch_size = batch_size |
| 66 | 68 | ||
| 67 | def prepare_subdata(self, template, data, num_class_images=1): | 69 | def prepare_subdata(self, template, data, num_class_images=1): |
| @@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 113 | size=self.size, interpolation=self.interpolation, | 115 | size=self.size, interpolation=self.interpolation, |
| 114 | center_crop=self.center_crop) | 116 | center_crop=self.center_crop) |
| 115 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 117 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 116 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 118 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, |
| 119 | num_workers=self.num_workers) | ||
| 117 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 120 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
| 118 | pin_memory=True, collate_fn=self.collate_fn) | 121 | pin_memory=True, collate_fn=self.collate_fn, |
| 122 | num_workers=self.num_workers) | ||
| 119 | 123 | ||
| 120 | def train_dataloader(self): | 124 | def train_dataloader(self): |
| 121 | return self.train_dataloader_ | 125 | return self.train_dataloader_ |
