diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-16 10:51:02 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-16 10:51:02 +0100 |
| commit | 9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f (patch) | |
| tree | b365457b40a54fed792f2e3e9d776389a5c9017f /data | |
| parent | Handle empty validation dataset (diff) | |
| download | textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.gz textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.bz2 textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.zip | |
Pad dataset if len(items) < batch_size
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/data/csv.py b/data/csv.py index 968af8d..dec66d7 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -267,6 +267,9 @@ class VlpnDataModule(): | |||
| 267 | items = self.prepare_items(template, expansions, items) | 267 | items = self.prepare_items(template, expansions, items) |
| 268 | items = self.filter_items(items) | 268 | items = self.filter_items(items) |
| 269 | 269 | ||
| 270 | if (len(items) < self.batch_size): | ||
| 271 | items = (items * self.batch_size)[:self.batch_size] | ||
| 272 | |||
| 270 | num_images = len(items) | 273 | num_images = len(items) |
| 271 | 274 | ||
| 272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 275 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
