From d0ce16b542deac464e097c38adc5095802bd6763 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 3 Oct 2022 11:44:42 +0200
Subject: Assign unused images in validation dataset to train dataset

---
 data/dreambooth/csv.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

(limited to 'data/dreambooth')

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 1676d35..4087226 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -21,6 +21,7 @@ class CSVDataModule(pl.LightningDataModule):
                  interpolation="bicubic",
                  identifier="*",
                  center_crop=False,
+                 valid_set_size=None,
                  collate_fn=None):
         super().__init__()
 
@@ -39,6 +40,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.identifier = identifier
         self.center_crop = center_crop
         self.interpolation = interpolation
+        self.valid_set_size = valid_set_size
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
@@ -50,8 +52,11 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
 
     def setup(self, stage=None):
-        train_set_size = int(len(self.data_full) * 0.8)
-        valid_set_size = len(self.data_full) - train_set_size
+        valid_set_size = int(len(self.data_full) * 0.2)
+        if self.valid_set_size:
+            valid_set_size = math.min(valid_set_size, self.valid_set_size)
+        train_set_size = len(self.data_full) - valid_set_size
+
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
-- 
cgit v1.2.3-70-g09d2