diff options
author | Volpeon <git@volpeon.ink> | 2023-02-15 12:22:44 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-15 12:22:44 +0100 |
commit | fad870919737a19ea28f0c501f8139ce6a98b7fb (patch) | |
tree | 4137f98df036bace7e1b40563560c257218306ee /data | |
parent | Better batch filling behavior (diff) | |
download | textual-inversion-diff-fad870919737a19ea28f0c501f8139ce6a98b7fb.tar.gz textual-inversion-diff-fad870919737a19ea28f0c501f8139ce6a98b7fb.tar.bz2 textual-inversion-diff-fad870919737a19ea28f0c501f8139ce6a98b7fb.zip |
Better batch filling
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 4ebdc1e..480e9f2 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -2,7 +2,6 @@ import math | |||
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | from functools import partial | 4 | from functools import partial |
5 | import itertools | ||
6 | from pathlib import Path | 5 | from pathlib import Path |
7 | from typing import NamedTuple, Optional, Union, Callable | 6 | from typing import NamedTuple, Optional, Union, Callable |
8 | 7 | ||
@@ -411,7 +410,9 @@ class VlpnDataset(IterableDataset): | |||
411 | if len(bucket_items) == 0: | 410 | if len(bucket_items) == 0: |
412 | if len(batch) != 0: | 411 | if len(batch) != 0: |
413 | if self.fill_batch: | 412 | if self.fill_batch: |
414 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | 413 | fill_items = self.bucket_items[self.bucket_assignments == bucket] |
414 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
415 | batch += fill_items[fill_perm] | ||
415 | yield batch | 416 | yield batch |
416 | batch = [] | 417 | batch = [] |
417 | 418 | ||
@@ -452,5 +453,7 @@ class VlpnDataset(IterableDataset): | |||
452 | 453 | ||
453 | if len(batch) != 0: | 454 | if len(batch) != 0: |
454 | if self.fill_batch: | 455 | if self.fill_batch: |
455 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | 456 | fill_items = self.bucket_items[self.bucket_assignments == bucket] |
457 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
458 | batch += fill_items[fill_perm] | ||
456 | yield batch | 459 | yield batch |