summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 10:51:02 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 10:51:02 +0100
commit9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f (patch)
treeb365457b40a54fed792f2e3e9d776389a5c9017f /data/csv.py
parentHandle empty validation dataset (diff)
downloadtextual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.gz
textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.bz2
textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.zip
Pad dataset if len(items) < batch_size
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py
index 968af8d..dec66d7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -267,6 +267,9 @@ class VlpnDataModule():
267 items = self.prepare_items(template, expansions, items) 267 items = self.prepare_items(template, expansions, items)
268 items = self.filter_items(items) 268 items = self.filter_items(items)
269 269
270 if (len(items) < self.batch_size):
271 items = (items * self.batch_size)[:self.batch_size]
272
270 num_images = len(items) 273 num_images = len(items)
271 274
272 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 275 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10