summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index f9b5e39..6bd7f9b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule):
38 center_crop: bool = False, 38 center_crop: bool = False,
39 valid_set_size: Optional[int] = None, 39 valid_set_size: Optional[int] = None,
40 generator: Optional[torch.Generator] = None, 40 generator: Optional[torch.Generator] = None,
41 collate_fn=None 41 collate_fn=None,
42 num_workers: int = 0
42 ): 43 ):
43 super().__init__() 44 super().__init__()
44 45
@@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule):
62 self.valid_set_size = valid_set_size 63 self.valid_set_size = valid_set_size
63 self.generator = generator 64 self.generator = generator
64 self.collate_fn = collate_fn 65 self.collate_fn = collate_fn
66 self.num_workers = num_workers
65 self.batch_size = batch_size 67 self.batch_size = batch_size
66 68
67 def prepare_subdata(self, template, data, num_class_images=1): 69 def prepare_subdata(self, template, data, num_class_images=1):
@@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule):
113 size=self.size, interpolation=self.interpolation, 115 size=self.size, interpolation=self.interpolation,
114 center_crop=self.center_crop) 116 center_crop=self.center_crop)
115 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 117 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
116 shuffle=True, pin_memory=True, collate_fn=self.collate_fn) 118 shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
119 num_workers=self.num_workers)
117 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 120 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
118 pin_memory=True, collate_fn=self.collate_fn) 121 pin_memory=True, collate_fn=self.collate_fn,
122 num_workers=self.num_workers)
119 123
120 def train_dataloader(self): 124 def train_dataloader(self):
121 return self.train_dataloader_ 125 return self.train_dataloader_