diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index af36d9e..e901ab4 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -59,7 +59,7 @@ class CSVDataModule(): | |||
| 59 | center_crop: bool = False, | 59 | center_crop: bool = False, |
| 60 | template_key: str = "template", | 60 | template_key: str = "template", |
| 61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
| 62 | generator: Optional[torch.Generator] = None, | 62 | seed: Optional[int] = None, |
| 63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, | 63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
| 64 | collate_fn=None, | 64 | collate_fn=None, |
| 65 | num_workers: int = 0 | 65 | num_workers: int = 0 |
| @@ -84,7 +84,7 @@ class CSVDataModule(): | |||
| 84 | self.template_key = template_key | 84 | self.template_key = template_key |
| 85 | self.interpolation = interpolation | 85 | self.interpolation = interpolation |
| 86 | self.valid_set_size = valid_set_size | 86 | self.valid_set_size = valid_set_size |
| 87 | self.generator = generator | 87 | self.seed = seed |
| 88 | self.filter = filter | 88 | self.filter = filter |
| 89 | self.collate_fn = collate_fn | 89 | self.collate_fn = collate_fn |
| 90 | self.num_workers = num_workers | 90 | self.num_workers = num_workers |
| @@ -155,7 +155,11 @@ class CSVDataModule(): | |||
| 155 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
| 156 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
| 157 | 157 | ||
| 158 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) | 158 | generator = torch.Generator(device="cpu") |
| 159 | if self.seed is not None: | ||
| 160 | generator = generator.manual_seed(self.seed) | ||
| 161 | |||
| 162 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | ||
| 159 | 163 | ||
| 160 | self.data_train = self.pad_items(data_train, self.num_class_images) | 164 | self.data_train = self.pad_items(data_train, self.num_class_images) |
| 161 | self.data_val = self.pad_items(data_val) | 165 | self.data_val = self.pad_items(data_val) |
