From 33e7d2ed37e32657ca94d92815043026c4cea7c0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 10 Jan 2023 09:22:02 +0100 Subject: Added arg to disable tag shuffling --- data/csv.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index ed8e93d..9ad7dd6 100644 --- a/data/csv.py +++ b/data/csv.py @@ -122,6 +122,7 @@ class VlpnDataModule(): bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, dropout: float = 0, + shuffle: bool = False, interpolation: str = "bicubic", template_key: str = "template", valid_set_size: Optional[int] = None, @@ -150,6 +151,7 @@ class VlpnDataModule(): self.bucket_max_pixels = bucket_max_pixels self.progressive_buckets = progressive_buckets self.dropout = dropout + self.shuffle = shuffle self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -240,7 +242,7 @@ class VlpnDataModule(): bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, - num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, + num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) val_dataset = VlpnDataset( -- cgit v1.2.3-70-g09d2