diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-15 10:26:07 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-15 10:26:07 +0100 |
| commit | 838e65823600611e281259b6f2d1f83a938bf7dc (patch) | |
| tree | 4b71778f1e917e0f59b9338f541c317f51a4488b | |
| parent | Made low-freq noise configurable (diff) | |
| download | textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.tar.gz textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.tar.bz2 textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.zip | |
Dataset: Repeat data to fill batch to batch_size
| -rw-r--r-- | data/csv.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py index c5902ed..913268f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -2,6 +2,7 @@ import math | |||
| 2 | import torch | 2 | import torch |
| 3 | import json | 3 | import json |
| 4 | from functools import partial | 4 | from functools import partial |
| 5 | import itertools | ||
| 5 | from pathlib import Path | 6 | from pathlib import Path |
| 6 | from typing import NamedTuple, Optional, Union, Callable | 7 | from typing import NamedTuple, Optional, Union, Callable |
| 7 | 8 | ||
| @@ -407,6 +408,7 @@ class VlpnDataset(IterableDataset): | |||
| 407 | 408 | ||
| 408 | if len(bucket_items) == 0: | 409 | if len(bucket_items) == 0: |
| 409 | if len(batch) != 0: | 410 | if len(batch) != 0: |
| 411 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
| 410 | yield batch | 412 | yield batch |
| 411 | batch = [] | 413 | batch = [] |
| 412 | 414 | ||
| @@ -446,4 +448,5 @@ class VlpnDataset(IterableDataset): | |||
| 446 | batch.append(example) | 448 | batch.append(example) |
| 447 | 449 | ||
| 448 | if len(batch) != 0: | 450 | if len(batch) != 0: |
| 451 | batch = list(itertools.islice(itertools.cycle(batch), batch_size)) | ||
| 449 | yield batch | 452 | yield batch |
