From fad870919737a19ea28f0c501f8139ce6a98b7fb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 15 Feb 2023 12:22:44 +0100 Subject: Better batch filling --- data/csv.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 4ebdc1e..480e9f2 100644 --- a/data/csv.py +++ b/data/csv.py @@ -2,7 +2,6 @@ import math import torch import json from functools import partial -import itertools from pathlib import Path from typing import NamedTuple, Optional, Union, Callable @@ -411,7 +410,9 @@ class VlpnDataset(IterableDataset): if len(bucket_items) == 0: if len(batch) != 0: if self.fill_batch: - batch = list(itertools.islice(itertools.cycle(batch), batch_size)) + fill_items = self.bucket_items[self.bucket_assignments == bucket] + fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) + batch += fill_items[fill_perm] yield batch batch = [] @@ -452,5 +453,7 @@ class VlpnDataset(IterableDataset): if len(batch) != 0: if self.fill_batch: - batch = list(itertools.islice(itertools.cycle(batch), batch_size)) + fill_items = self.bucket_items[self.bucket_assignments == bucket] + fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) + batch += fill_items[fill_perm] yield batch -- cgit v1.2.3-70-g09d2