summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 18:45:03 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 18:45:03 +0100
commit96e887da4be2c13f5f58da3359a9ab891c44d050 (patch)
treebe5824b20be60ddd73374a0b962561eaea62fdc1 /data
parentMoved multi-TI code from Dreambooth to TI script (diff)
downloadtextual-inversion-diff-96e887da4be2c13f5f58da3359a9ab891c44d050.tar.gz
textual-inversion-diff-96e887da4be2c13f5f58da3359a9ab891c44d050.tar.bz2
textual-inversion-diff-96e887da4be2c13f5f58da3359a9ab891c44d050.zip
If valid set size is 0, re-use one image from train set
Diffstat (limited to 'data')
-rw-r--r--data/csv.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py
index 6857b6f..85b98f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) 282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
283 283
284 if valid_set_size == 0: 284 if valid_set_size == 0:
285 data_train, data_val = items, [] 285 data_train, data_val = items, items[:1]
286 else: 286 else:
287 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) 287 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
288 288