diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index af36d9e..e901ab4 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -59,7 +59,7 @@ class CSVDataModule(): | |||
59 | center_crop: bool = False, | 59 | center_crop: bool = False, |
60 | template_key: str = "template", | 60 | template_key: str = "template", |
61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
62 | generator: Optional[torch.Generator] = None, | 62 | seed: Optional[int] = None, |
63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, | 63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
64 | collate_fn=None, | 64 | collate_fn=None, |
65 | num_workers: int = 0 | 65 | num_workers: int = 0 |
@@ -84,7 +84,7 @@ class CSVDataModule(): | |||
84 | self.template_key = template_key | 84 | self.template_key = template_key |
85 | self.interpolation = interpolation | 85 | self.interpolation = interpolation |
86 | self.valid_set_size = valid_set_size | 86 | self.valid_set_size = valid_set_size |
87 | self.generator = generator | 87 | self.seed = seed |
88 | self.filter = filter | 88 | self.filter = filter |
89 | self.collate_fn = collate_fn | 89 | self.collate_fn = collate_fn |
90 | self.num_workers = num_workers | 90 | self.num_workers = num_workers |
@@ -155,7 +155,11 @@ class CSVDataModule(): | |||
155 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
156 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
157 | 157 | ||
158 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) | 158 | generator = torch.Generator(device="cpu") |
159 | if self.seed is not None: | ||
160 | generator = generator.manual_seed(self.seed) | ||
161 | |||
162 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | ||
159 | 163 | ||
160 | self.data_train = self.pad_items(data_train, self.num_class_images) | 164 | self.data_train = self.pad_items(data_train, self.num_class_images) |
161 | self.data_val = self.pad_items(data_val) | 165 | self.data_val = self.pad_items(data_val) |