From 838e65823600611e281259b6f2d1f83a938bf7dc Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 15 Feb 2023 10:26:07 +0100 Subject: Dataset: Repeat data to fill batch to batch_size --- data/csv.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index c5902ed..913268f 100644 --- a/data/csv.py +++ b/data/csv.py @@ -2,6 +2,7 @@ import math import torch import json from functools import partial +import itertools from pathlib import Path from typing import NamedTuple, Optional, Union, Callable @@ -407,6 +408,7 @@ class VlpnDataset(IterableDataset): if len(bucket_items) == 0: if len(batch) != 0: + batch = list(itertools.islice(itertools.cycle(batch), batch_size)) yield batch batch = [] @@ -446,4 +448,5 @@ class VlpnDataset(IterableDataset): batch.append(example) if len(batch) != 0: + batch = list(itertools.islice(itertools.cycle(batch), batch_size)) yield batch -- cgit v1.2.3-54-g00ecf