diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py index 2f0a392..584a40c 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -125,6 +125,7 @@ class VlpnDataModule(): | |||
| 125 | interpolation: str = "bicubic", | 125 | interpolation: str = "bicubic", |
| 126 | template_key: str = "template", | 126 | template_key: str = "template", |
| 127 | valid_set_size: Optional[int] = None, | 127 | valid_set_size: Optional[int] = None, |
| 128 | valid_set_repeat: int = 1, | ||
| 128 | seed: Optional[int] = None, | 129 | seed: Optional[int] = None, |
| 129 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 130 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
| 130 | collate_fn=None, | 131 | collate_fn=None, |
| @@ -152,6 +153,7 @@ class VlpnDataModule(): | |||
| 152 | self.template_key = template_key | 153 | self.template_key = template_key |
| 153 | self.interpolation = interpolation | 154 | self.interpolation = interpolation |
| 154 | self.valid_set_size = valid_set_size | 155 | self.valid_set_size = valid_set_size |
| 156 | self.valid_set_repeat = valid_set_repeat | ||
| 155 | self.seed = seed | 157 | self.seed = seed |
| 156 | self.filter = filter | 158 | self.filter = filter |
| 157 | self.collate_fn = collate_fn | 159 | self.collate_fn = collate_fn |
| @@ -243,6 +245,7 @@ class VlpnDataModule(): | |||
| 243 | 245 | ||
| 244 | val_dataset = VlpnDataset( | 246 | val_dataset = VlpnDataset( |
| 245 | self.data_val, self.prompt_processor, | 247 | self.data_val, self.prompt_processor, |
| 248 | repeat=self.valid_set_repeat, | ||
| 246 | batch_size=self.batch_size, generator=generator, | 249 | batch_size=self.batch_size, generator=generator, |
| 247 | size=self.size, interpolation=self.interpolation, | 250 | size=self.size, interpolation=self.interpolation, |
| 248 | ) | 251 | ) |
| @@ -267,6 +270,7 @@ class VlpnDataset(IterableDataset): | |||
| 267 | bucket_step_size: int = 64, | 270 | bucket_step_size: int = 64, |
| 268 | bucket_max_pixels: Optional[int] = None, | 271 | bucket_max_pixels: Optional[int] = None, |
| 269 | progressive_buckets: bool = False, | 272 | progressive_buckets: bool = False, |
| 273 | repeat: int = 1, | ||
| 270 | batch_size: int = 1, | 274 | batch_size: int = 1, |
| 271 | num_class_images: int = 0, | 275 | num_class_images: int = 0, |
| 272 | size: int = 768, | 276 | size: int = 768, |
| @@ -275,7 +279,7 @@ class VlpnDataset(IterableDataset): | |||
| 275 | interpolation: str = "bicubic", | 279 | interpolation: str = "bicubic", |
| 276 | generator: Optional[torch.Generator] = None, | 280 | generator: Optional[torch.Generator] = None, |
| 277 | ): | 281 | ): |
| 278 | self.items = items | 282 | self.items = items * repeat |
| 279 | self.batch_size = batch_size | 283 | self.batch_size = batch_size |
| 280 | 284 | ||
| 281 | self.prompt_processor = prompt_processor | 285 | self.prompt_processor = prompt_processor |
