diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 47 |
1 file changed, 27 insertions, 20 deletions
diff --git a/data/csv.py b/data/csv.py
index 002fdd2..968af8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -269,18 +269,22 @@ class VlpnDataModule(): | |||
269 | 269 | ||
270 | num_images = len(items) | 270 | num_images = len(items) |
271 | 271 | ||
272 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 | 272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
273 | valid_set_size = max(valid_set_size, 1) | 273 | train_set_size = max(num_images - valid_set_size, 1) |
274 | train_set_size = num_images - valid_set_size | 274 | valid_set_size = num_images - train_set_size |
275 | 275 | ||
276 | generator = torch.Generator(device="cpu") | 276 | generator = torch.Generator(device="cpu") |
277 | if self.seed is not None: | 277 | if self.seed is not None: |
278 | generator = generator.manual_seed(self.seed) | 278 | generator = generator.manual_seed(self.seed) |
279 | 279 | ||
280 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 280 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
281 | |||
282 | if valid_set_size == 0: | ||
283 | data_train, data_val = items, [] | ||
284 | else: | ||
285 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | ||
281 | 286 | ||
282 | self.data_train = self.pad_items(data_train, self.num_class_images) | 287 | self.data_train = self.pad_items(data_train, self.num_class_images) |
283 | self.data_val = self.pad_items(data_val) | ||
284 | 288 | ||
285 | train_dataset = VlpnDataset( | 289 | train_dataset = VlpnDataset( |
286 | self.data_train, self.tokenizer, | 290 | self.data_train, self.tokenizer, |
@@ -291,26 +295,29 @@ class VlpnDataModule(): | |||
291 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 295 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
292 | ) | 296 | ) |
293 | 297 | ||
294 | val_dataset = VlpnDataset( | ||
295 | self.data_val, self.tokenizer, | ||
296 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
297 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
298 | repeat=self.valid_set_repeat, | ||
299 | batch_size=self.batch_size, generator=generator, | ||
300 | size=self.size, interpolation=self.interpolation, | ||
301 | ) | ||
302 | |||
303 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | ||
304 | |||
305 | self.train_dataloader = DataLoader( | 298 | self.train_dataloader = DataLoader( |
306 | train_dataset, | 299 | train_dataset, |
307 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 300 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
308 | ) | 301 | ) |
309 | 302 | ||
310 | self.val_dataloader = DataLoader( | 303 | if valid_set_size != 0: |
311 | val_dataset, | 304 | self.data_val = self.pad_items(data_val) |
312 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 305 | |
313 | ) | 306 | val_dataset = VlpnDataset( |
307 | self.data_val, self.tokenizer, | ||
308 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
309 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
310 | repeat=self.valid_set_repeat, | ||
311 | batch_size=self.batch_size, generator=generator, | ||
312 | size=self.size, interpolation=self.interpolation, | ||
313 | ) | ||
314 | |||
315 | self.val_dataloader = DataLoader( | ||
316 | val_dataset, | ||
317 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
318 | ) | ||
319 | else: | ||
320 | self.val_dataloader = None | ||
314 | 321 | ||
315 | 322 | ||
316 | class VlpnDataset(IterableDataset): | 323 | class VlpnDataset(IterableDataset): |