From 403f525d0c6900cc6844c0d2f4ecb385fc131969 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 09:40:24 +0100 Subject: Fixed reproducibility, more consistant validation --- data/csv.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index af36d9e..e901ab4 100644 --- a/data/csv.py +++ b/data/csv.py @@ -59,7 +59,7 @@ class CSVDataModule(): center_crop: bool = False, template_key: str = "template", valid_set_size: Optional[int] = None, - generator: Optional[torch.Generator] = None, + seed: Optional[int] = None, filter: Optional[Callable[[CSVDataItem], bool]] = None, collate_fn=None, num_workers: int = 0 @@ -84,7 +84,7 @@ class CSVDataModule(): self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size - self.generator = generator + self.seed = seed self.filter = filter self.collate_fn = collate_fn self.num_workers = num_workers @@ -155,7 +155,11 @@ class CSVDataModule(): valid_set_size = max(valid_set_size, 1) train_set_size = num_images - valid_set_size - data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) + generator = torch.Generator(device="cpu") + if self.seed is not None: + generator = generator.manual_seed(self.seed) + + data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) self.data_train = self.pad_items(data_train, self.num_class_images) self.data_val = self.pad_items(data_val) -- cgit v1.2.3-54-g00ecf