From 9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 10:51:02 +0100 Subject: Pad dataset if len(items) < batch_size --- data/csv.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index 968af8d..dec66d7 100644 --- a/data/csv.py +++ b/data/csv.py @@ -267,6 +267,9 @@ class VlpnDataModule(): items = self.prepare_items(template, expansions, items) items = self.filter_items(items) + if (len(items) < self.batch_size): + items = (items * self.batch_size)[:self.batch_size] + num_images = len(items) valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 -- cgit v1.2.3-54-g00ecf