diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py index ed8e93d..9ad7dd6 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -122,6 +122,7 @@ class VlpnDataModule(): | |||
122 | bucket_max_pixels: Optional[int] = None, | 122 | bucket_max_pixels: Optional[int] = None, |
123 | progressive_buckets: bool = False, | 123 | progressive_buckets: bool = False, |
124 | dropout: float = 0, | 124 | dropout: float = 0, |
125 | shuffle: bool = False, | ||
125 | interpolation: str = "bicubic", | 126 | interpolation: str = "bicubic", |
126 | template_key: str = "template", | 127 | template_key: str = "template", |
127 | valid_set_size: Optional[int] = None, | 128 | valid_set_size: Optional[int] = None, |
@@ -150,6 +151,7 @@ class VlpnDataModule(): | |||
150 | self.bucket_max_pixels = bucket_max_pixels | 151 | self.bucket_max_pixels = bucket_max_pixels |
151 | self.progressive_buckets = progressive_buckets | 152 | self.progressive_buckets = progressive_buckets |
152 | self.dropout = dropout | 153 | self.dropout = dropout |
154 | self.shuffle = shuffle | ||
153 | self.template_key = template_key | 155 | self.template_key = template_key |
154 | self.interpolation = interpolation | 156 | self.interpolation = interpolation |
155 | self.valid_set_size = valid_set_size | 157 | self.valid_set_size = valid_set_size |
@@ -240,7 +242,7 @@ class VlpnDataModule(): | |||
240 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 242 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
241 | batch_size=self.batch_size, generator=generator, | 243 | batch_size=self.batch_size, generator=generator, |
242 | size=self.size, interpolation=self.interpolation, | 244 | size=self.size, interpolation=self.interpolation, |
243 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 245 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
244 | ) | 246 | ) |
245 | 247 | ||
246 | val_dataset = VlpnDataset( | 248 | val_dataset = VlpnDataset( |