diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 47 |
1 files changed, 27 insertions, 20 deletions
diff --git a/data/csv.py b/data/csv.py index 002fdd2..968af8d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -269,18 +269,22 @@ class VlpnDataModule(): | |||
| 269 | 269 | ||
| 270 | num_images = len(items) | 270 | num_images = len(items) |
| 271 | 271 | ||
| 272 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 | 272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
| 273 | valid_set_size = max(valid_set_size, 1) | 273 | train_set_size = max(num_images - valid_set_size, 1) |
| 274 | train_set_size = num_images - valid_set_size | 274 | valid_set_size = num_images - train_set_size |
| 275 | 275 | ||
| 276 | generator = torch.Generator(device="cpu") | 276 | generator = torch.Generator(device="cpu") |
| 277 | if self.seed is not None: | 277 | if self.seed is not None: |
| 278 | generator = generator.manual_seed(self.seed) | 278 | generator = generator.manual_seed(self.seed) |
| 279 | 279 | ||
| 280 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 280 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
| 281 | |||
| 282 | if valid_set_size == 0: | ||
| 283 | data_train, data_val = items, [] | ||
| 284 | else: | ||
| 285 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | ||
| 281 | 286 | ||
| 282 | self.data_train = self.pad_items(data_train, self.num_class_images) | 287 | self.data_train = self.pad_items(data_train, self.num_class_images) |
| 283 | self.data_val = self.pad_items(data_val) | ||
| 284 | 288 | ||
| 285 | train_dataset = VlpnDataset( | 289 | train_dataset = VlpnDataset( |
| 286 | self.data_train, self.tokenizer, | 290 | self.data_train, self.tokenizer, |
| @@ -291,26 +295,29 @@ class VlpnDataModule(): | |||
| 291 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 295 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
| 292 | ) | 296 | ) |
| 293 | 297 | ||
| 294 | val_dataset = VlpnDataset( | ||
| 295 | self.data_val, self.tokenizer, | ||
| 296 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
| 297 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
| 298 | repeat=self.valid_set_repeat, | ||
| 299 | batch_size=self.batch_size, generator=generator, | ||
| 300 | size=self.size, interpolation=self.interpolation, | ||
| 301 | ) | ||
| 302 | |||
| 303 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | ||
| 304 | |||
| 305 | self.train_dataloader = DataLoader( | 298 | self.train_dataloader = DataLoader( |
| 306 | train_dataset, | 299 | train_dataset, |
| 307 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 300 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 308 | ) | 301 | ) |
| 309 | 302 | ||
| 310 | self.val_dataloader = DataLoader( | 303 | if valid_set_size != 0: |
| 311 | val_dataset, | 304 | self.data_val = self.pad_items(data_val) |
| 312 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 305 | |
| 313 | ) | 306 | val_dataset = VlpnDataset( |
| 307 | self.data_val, self.tokenizer, | ||
| 308 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
| 309 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
| 310 | repeat=self.valid_set_repeat, | ||
| 311 | batch_size=self.batch_size, generator=generator, | ||
| 312 | size=self.size, interpolation=self.interpolation, | ||
| 313 | ) | ||
| 314 | |||
| 315 | self.val_dataloader = DataLoader( | ||
| 316 | val_dataset, | ||
| 317 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
| 318 | ) | ||
| 319 | else: | ||
| 320 | self.val_dataloader = None | ||
| 314 | 321 | ||
| 315 | 322 | ||
| 316 | class VlpnDataset(IterableDataset): | 323 | class VlpnDataset(IterableDataset): |
