summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- data/csv.py 9
1 files changed, 6 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 4ebdc1e..480e9f2 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,7 +2,6 @@ import math
2import torch 2import torch
3import json 3import json
4from functools import partial 4from functools import partial
5import itertools
6from pathlib import Path 5from pathlib import Path
7from typing import NamedTuple, Optional, Union, Callable 6from typing import NamedTuple, Optional, Union, Callable
8 7
@@ -411,7 +410,9 @@ class VlpnDataset(IterableDataset):
411 if len(bucket_items) == 0: 410 if len(bucket_items) == 0:
412 if len(batch) != 0: 411 if len(batch) != 0:
413 if self.fill_batch: 412 if self.fill_batch:
414 batch = list(itertools.islice(itertools.cycle(batch), batch_size)) 413 fill_items = self.bucket_items[self.bucket_assignments == bucket]
414 fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator)
415 batch += fill_items[fill_perm]
415 yield batch 416 yield batch
416 batch = [] 417 batch = []
417 418
@@ -452,5 +453,7 @@ class VlpnDataset(IterableDataset):
452 453
453 if len(batch) != 0: 454 if len(batch) != 0:
454 if self.fill_batch: 455 if self.fill_batch:
455 batch = list(itertools.islice(itertools.cycle(batch), batch_size)) 456 fill_items = self.bucket_items[self.bucket_assignments == bucket]
457 fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator)
458 batch += fill_items[fill_perm]
456 yield batch 459 yield batch