diff options
Diffstat (limited to 'data/csv.py')
| -rw-r--r-- | data/csv.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py index c5902ed..913268f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -2,6 +2,7 @@ import math | |||
| 2 | import torch | 2 | import torch |
| 3 | import json | 3 | import json |
| 4 | from functools import partial | 4 | from functools import partial |
| 5 | import itertools | ||
| 5 | from pathlib import Path | 6 | from pathlib import Path |
| 6 | from typing import NamedTuple, Optional, Union, Callable | 7 | from typing import NamedTuple, Optional, Union, Callable |
| 7 | 8 | ||
| @@ -407,6 +408,7 @@ class VlpnDataset(IterableDataset): | |||
| 407 | 408 | ||
| 408 | if len(bucket_items) == 0: | 409 | if len(bucket_items) == 0: |
| 409 | if len(batch) != 0: | 410 | if len(batch) != 0: |
| 411 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
| 410 | yield batch | 412 | yield batch |
| 411 | batch = [] | 413 | batch = [] |
| 412 | 414 | ||
| @@ -446,4 +448,5 @@ class VlpnDataset(IterableDataset): | |||
| 446 | batch.append(example) | 448 | batch.append(example) |
| 447 | 449 | ||
| 448 | if len(batch) != 0: | 450 | if len(batch) != 0: |
| 451 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
| 449 | yield batch | 452 | yield batch |
