summary | refs | log | tree | commit | diff | stats
path: root/data/csv.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-02-15 12:22:44 +0100
committer: Volpeon <git@volpeon.ink> 2023-02-15 12:22:44 +0100
commit: fad870919737a19ea28f0c501f8139ce6a98b7fb (patch)
tree: 4137f98df036bace7e1b40563560c257218306ee /data/csv.py
parent: Better batch filling behavior (diff)
download: textual-inversion-diff-fad870919737a19ea28f0c501f8139ce6a98b7fb.tar.gz
textual-inversion-diff-fad870919737a19ea28f0c501f8139ce6a98b7fb.tar.bz2
textual-inversion-diff-fad870919737a19ea28f0c501f8139ce6a98b7fb.zip
Better batch filling
Diffstat (limited to 'data/csv.py')
-rw-r--r-- data/csv.py 9
1 file changed, 6 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 4ebdc1e..480e9f2 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,7 +2,6 @@ import math
2import torch 2import torch
3import json 3import json
4from functools import partial 4from functools import partial
5import itertools
6from pathlib import Path 5from pathlib import Path
7from typing import NamedTuple, Optional, Union, Callable 6from typing import NamedTuple, Optional, Union, Callable
8 7
@@ -411,7 +410,9 @@ class VlpnDataset(IterableDataset):
411 if len(bucket_items) == 0: 410 if len(bucket_items) == 0:
412 if len(batch) != 0: 411 if len(batch) != 0:
413 if self.fill_batch: 412 if self.fill_batch:
414 batch = list(itertools.islice(itertools.cycle(batch), batch_size)) 413 fill_items = self.bucket_items[self.bucket_assignments == bucket]
414 fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator)
415 batch += fill_items[fill_perm]
415 yield batch 416 yield batch
416 batch = [] 417 batch = []
417 418
@@ -452,5 +453,7 @@ class VlpnDataset(IterableDataset):
452 453
453 if len(batch) != 0: 454 if len(batch) != 0:
454 if self.fill_batch: 455 if self.fill_batch:
455 batch = list(itertools.islice(itertools.cycle(batch), batch_size)) 456 fill_items = self.bucket_items[self.bucket_assignments == bucket]
457 fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator)
458 batch += fill_items[fill_perm]
456 yield batch 459 yield batch