summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py
index 2f0a392..584a40c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -125,6 +125,7 @@ class VlpnDataModule():
125 interpolation: str = "bicubic", 125 interpolation: str = "bicubic",
126 template_key: str = "template", 126 template_key: str = "template",
127 valid_set_size: Optional[int] = None, 127 valid_set_size: Optional[int] = None,
128 valid_set_repeat: int = 1,
128 seed: Optional[int] = None, 129 seed: Optional[int] = None,
129 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 130 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
130 collate_fn=None, 131 collate_fn=None,
@@ -152,6 +153,7 @@ class VlpnDataModule():
152 self.template_key = template_key 153 self.template_key = template_key
153 self.interpolation = interpolation 154 self.interpolation = interpolation
154 self.valid_set_size = valid_set_size 155 self.valid_set_size = valid_set_size
156 self.valid_set_repeat = valid_set_repeat
155 self.seed = seed 157 self.seed = seed
156 self.filter = filter 158 self.filter = filter
157 self.collate_fn = collate_fn 159 self.collate_fn = collate_fn
@@ -243,6 +245,7 @@ class VlpnDataModule():
243 245
244 val_dataset = VlpnDataset( 246 val_dataset = VlpnDataset(
245 self.data_val, self.prompt_processor, 247 self.data_val, self.prompt_processor,
248 repeat=self.valid_set_repeat,
246 batch_size=self.batch_size, generator=generator, 249 batch_size=self.batch_size, generator=generator,
247 size=self.size, interpolation=self.interpolation, 250 size=self.size, interpolation=self.interpolation,
248 ) 251 )
@@ -267,6 +270,7 @@ class VlpnDataset(IterableDataset):
267 bucket_step_size: int = 64, 270 bucket_step_size: int = 64,
268 bucket_max_pixels: Optional[int] = None, 271 bucket_max_pixels: Optional[int] = None,
269 progressive_buckets: bool = False, 272 progressive_buckets: bool = False,
273 repeat: int = 1,
270 batch_size: int = 1, 274 batch_size: int = 1,
271 num_class_images: int = 0, 275 num_class_images: int = 0,
272 size: int = 768, 276 size: int = 768,
@@ -275,7 +279,7 @@ class VlpnDataset(IterableDataset):
275 interpolation: str = "bicubic", 279 interpolation: str = "bicubic",
276 generator: Optional[torch.Generator] = None, 280 generator: Optional[torch.Generator] = None,
277 ): 281 ):
278 self.items = items 282 self.items = items * repeat
279 self.batch_size = batch_size 283 self.batch_size = batch_size
280 284
281 self.prompt_processor = prompt_processor 285 self.prompt_processor = prompt_processor