summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-15 10:26:07 +0100
committerVolpeon <git@volpeon.ink>2023-02-15 10:26:07 +0100
commit838e65823600611e281259b6f2d1f83a938bf7dc (patch)
tree4b71778f1e917e0f59b9338f541c317f51a4488b /data/csv.py
parentMade low-freq noise configurable (diff)
downloadtextual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.tar.gz
textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.tar.bz2
textual-inversion-diff-838e65823600611e281259b6f2d1f83a938bf7dc.zip
Dataset: Repeat data to fill batch to batch_size
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py
index c5902ed..913268f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,6 +2,7 @@ import math
2import torch 2import torch
3import json 3import json
4from functools import partial 4from functools import partial
5import itertools
5from pathlib import Path 6from pathlib import Path
6from typing import NamedTuple, Optional, Union, Callable 7from typing import NamedTuple, Optional, Union, Callable
7 8
@@ -407,6 +408,7 @@ class VlpnDataset(IterableDataset):
407 408
408 if len(bucket_items) == 0: 409 if len(bucket_items) == 0:
409 if len(batch) != 0: 410 if len(batch) != 0:
411 batch = list(itertools.islice(itertools.cycle(batch), batch_size))
410 yield batch 412 yield batch
411 batch = [] 413 batch = []
412 414
@@ -446,4 +448,5 @@ class VlpnDataset(IterableDataset):
446 batch.append(example) 448 batch.append(example)
447 449
448 if len(batch) != 0: 450 if len(batch) != 0:
451 batch = list(itertools.islice(itertools.cycle(batch), batch_size))
449 yield batch 452 yield batch