diff options
author | Volpeon <git@volpeon.ink> | 2023-01-16 10:51:02 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-16 10:51:02 +0100 |
commit | 9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f (patch) | |
tree | b365457b40a54fed792f2e3e9d776389a5c9017f /data | |
parent | Handle empty validation dataset (diff) | |
download | textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.gz textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.bz2 textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.zip |
Pad dataset if len(items) < batch_size
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py index 968af8d..dec66d7 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -267,6 +267,9 @@ class VlpnDataModule(): | |||
267 | items = self.prepare_items(template, expansions, items) | 267 | items = self.prepare_items(template, expansions, items) |
268 | items = self.filter_items(items) | 268 | items = self.filter_items(items) |
269 | 269 | ||
270 | if (len(items) < self.batch_size): | ||
271 | items = (items * self.batch_size)[:self.batch_size] | ||
272 | |||
270 | num_images = len(items) | 273 | num_images = len(items) |
271 | 274 | ||
272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 275 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |