diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index f9b5e39..6bd7f9b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
38 | center_crop: bool = False, | 38 | center_crop: bool = False, |
39 | valid_set_size: Optional[int] = None, | 39 | valid_set_size: Optional[int] = None, |
40 | generator: Optional[torch.Generator] = None, | 40 | generator: Optional[torch.Generator] = None, |
41 | collate_fn=None | 41 | collate_fn=None, |
42 | num_workers: int = 0 | ||
42 | ): | 43 | ): |
43 | super().__init__() | 44 | super().__init__() |
44 | 45 | ||
@@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
62 | self.valid_set_size = valid_set_size | 63 | self.valid_set_size = valid_set_size |
63 | self.generator = generator | 64 | self.generator = generator |
64 | self.collate_fn = collate_fn | 65 | self.collate_fn = collate_fn |
66 | self.num_workers = num_workers | ||
65 | self.batch_size = batch_size | 67 | self.batch_size = batch_size |
66 | 68 | ||
67 | def prepare_subdata(self, template, data, num_class_images=1): | 69 | def prepare_subdata(self, template, data, num_class_images=1): |
@@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
113 | size=self.size, interpolation=self.interpolation, | 115 | size=self.size, interpolation=self.interpolation, |
114 | center_crop=self.center_crop) | 116 | center_crop=self.center_crop) |
115 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 117 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
116 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 118 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, |
119 | num_workers=self.num_workers) | ||
117 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 120 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
118 | pin_memory=True, collate_fn=self.collate_fn) | 121 | pin_memory=True, collate_fn=self.collate_fn, |
122 | num_workers=self.num_workers) | ||
119 | 123 | ||
120 | def train_dataloader(self): | 124 | def train_dataloader(self): |
121 | return self.train_dataloader_ | 125 | return self.train_dataloader_ |