summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py
index c5902ed..913268f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,6 +2,7 @@ import math
2import torch 2import torch
3import json 3import json
4from functools import partial 4from functools import partial
5import itertools
5from pathlib import Path 6from pathlib import Path
6from typing import NamedTuple, Optional, Union, Callable 7from typing import NamedTuple, Optional, Union, Callable
7 8
@@ -407,6 +408,7 @@ class VlpnDataset(IterableDataset):
407 408
408 if len(bucket_items) == 0: 409 if len(bucket_items) == 0:
409 if len(batch) != 0: 410 if len(batch) != 0:
411 batch = list(itertools.islice(itertools.cycle(batch), batch_size))
410 yield batch 412 yield batch
411 batch = [] 413 batch = []
412 414
@@ -446,4 +448,5 @@ class VlpnDataset(IterableDataset):
446 batch.append(example) 448 batch.append(example)
447 449
448 if len(batch) != 0: 450 if len(batch) != 0:
451 batch = list(itertools.islice(itertools.cycle(batch), batch_size))
449 yield batch 452 yield batch