From 89afcfda3f824cc44221e877182348f9b09687d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 10:31:55 +0100 Subject: Handle empty validation dataset --- data/csv.py | 47 +++++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 20 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 002fdd2..968af8d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -269,18 +269,22 @@ class VlpnDataModule(): num_images = len(items) - valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 - valid_set_size = max(valid_set_size, 1) - train_set_size = num_images - valid_set_size + valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 + train_set_size = max(num_images - valid_set_size, 1) + valid_set_size = num_images - train_set_size generator = torch.Generator(device="cpu") if self.seed is not None: generator = generator.manual_seed(self.seed) - data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) + collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) + + if valid_set_size == 0: + data_train, data_val = items, [] + else: + data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) self.data_train = self.pad_items(data_train, self.num_class_images) - self.data_val = self.pad_items(data_val) train_dataset = VlpnDataset( self.data_train, self.tokenizer, @@ -291,26 +295,29 @@ class VlpnDataModule(): num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) - val_dataset = VlpnDataset( - self.data_val, self.tokenizer, - num_buckets=self.num_buckets, progressive_buckets=True, - bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - repeat=self.valid_set_repeat, - batch_size=self.batch_size, generator=generator, - size=self.size, interpolation=self.interpolation, - ) - - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) - self.train_dataloader = DataLoader( train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) - self.val_dataloader = DataLoader( - val_dataset, - batch_size=None, pin_memory=True, collate_fn=collate_fn_ - ) + if valid_set_size != 0: + self.data_val = self.pad_items(data_val) + + val_dataset = VlpnDataset( + self.data_val, self.tokenizer, + num_buckets=self.num_buckets, progressive_buckets=True, + bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, + repeat=self.valid_set_repeat, + batch_size=self.batch_size, generator=generator, + size=self.size, interpolation=self.interpolation, + ) + + self.val_dataloader = DataLoader( + val_dataset, + batch_size=None, pin_memory=True, collate_fn=collate_fn_ + ) + else: + self.val_dataloader = None class VlpnDataset(IterableDataset): -- cgit v1.2.3-70-g09d2