diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 9 | 
1 files changed, 6 insertions, 3 deletions
| diff --git a/data/csv.py b/data/csv.py index 4ebdc1e..480e9f2 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -2,7 +2,6 @@ import math | |||
| 2 | import torch | 2 | import torch | 
| 3 | import json | 3 | import json | 
| 4 | from functools import partial | 4 | from functools import partial | 
| 5 | import itertools | ||
| 6 | from pathlib import Path | 5 | from pathlib import Path | 
| 7 | from typing import NamedTuple, Optional, Union, Callable | 6 | from typing import NamedTuple, Optional, Union, Callable | 
| 8 | 7 | ||
| @@ -411,7 +410,9 @@ class VlpnDataset(IterableDataset): | |||
| 411 | if len(bucket_items) == 0: | 410 | if len(bucket_items) == 0: | 
| 412 | if len(batch) != 0: | 411 | if len(batch) != 0: | 
| 413 | if self.fill_batch: | 412 | if self.fill_batch: | 
| 414 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | 413 | fill_items = self.bucket_items[self.bucket_assignments == bucket] | 
| 414 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
| 415 | batch += fill_items[fill_perm] | ||
| 415 | yield batch | 416 | yield batch | 
| 416 | batch = [] | 417 | batch = [] | 
| 417 | 418 | ||
| @@ -452,5 +453,7 @@ class VlpnDataset(IterableDataset): | |||
| 452 | 453 | ||
| 453 | if len(batch) != 0: | 454 | if len(batch) != 0: | 
| 454 | if self.fill_batch: | 455 | if self.fill_batch: | 
| 455 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | 456 | fill_items = self.bucket_items[self.bucket_assignments == bucket] | 
| 457 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
| 458 | batch += fill_items[fill_perm] | ||
| 456 | yield batch | 459 | yield batch | 
