diff options
author | Volpeon <git@volpeon.ink> | 2023-02-15 10:26:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-15 10:26:07 +0100 |
commit | 838e65823600611e281259b6f2d1f83a938bf7dc (patch) | |
tree | 4b71778f1e917e0f59b9338f541c317f51a4488b | |
parent | Made low-freq noise configurable (diff) | |
download | textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.tar.gz textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.tar.bz2 textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.zip |
Dataset: Repeat data to fill batch to batch_size
-rw-r--r-- | data/csv.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py index c5902ed..913268f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -2,6 +2,7 @@ import math | |||
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | from functools import partial | 4 | from functools import partial |
5 | import itertools | ||
5 | from pathlib import Path | 6 | from pathlib import Path |
6 | from typing import NamedTuple, Optional, Union, Callable | 7 | from typing import NamedTuple, Optional, Union, Callable |
7 | 8 | ||
@@ -407,6 +408,7 @@ class VlpnDataset(IterableDataset): | |||
407 | 408 | ||
408 | if len(bucket_items) == 0: | 409 | if len(bucket_items) == 0: |
409 | if len(batch) != 0: | 410 | if len(batch) != 0: |
411 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
410 | yield batch | 412 | yield batch |
411 | batch = [] | 413 | batch = [] |
412 | 414 | ||
@@ -446,4 +448,5 @@ class VlpnDataset(IterableDataset): | |||
446 | batch.append(example) | 448 | batch.append(example) |
447 | 449 | ||
448 | if len(batch) != 0: | 450 | if len(batch) != 0: |
451 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
449 | yield batch | 452 | yield batch |