From dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Fri, 30 Dec 2022 13:48:26 +0100
Subject: Training script improvements

---
 data/csv.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

(limited to 'data')

diff --git a/data/csv.py b/data/csv.py
index 0ad36dc..4da5d64 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple):
     prompt: list[str]
     cprompt: str
     nprompt: str
+    mode: list[str]
 
 
 class CSVDataModule():
@@ -56,7 +57,6 @@ class CSVDataModule():
             dropout: float = 0,
             interpolation: str = "bicubic",
             center_crop: bool = False,
-            mode: Optional[str] = None,
             template_key: str = "template",
             valid_set_size: Optional[int] = None,
             generator: Optional[torch.Generator] = None,
@@ -81,7 +81,6 @@ class CSVDataModule():
         self.repeats = repeats
         self.dropout = dropout
         self.center_crop = center_crop
-        self.mode = mode
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -113,6 +112,7 @@ class CSVDataModule():
                     nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                     expansions
                 )),
+                item["mode"].split(", ") if "mode" in item else []
             )
             for item in data
         ]
@@ -133,6 +133,7 @@ class CSVDataModule():
                 item.prompt,
                 item.cprompt,
                 item.nprompt,
+                item.mode,
             )
             for item in items
             for i in range(image_multiplier)
@@ -145,20 +146,12 @@ class CSVDataModule():
         expansions = metadata["expansions"] if "expansions" in metadata else {}
         items = metadata["items"] if "items" in metadata else []
 
-        if self.mode is not None:
-            items = [
-                item
-                for item in items
-                if "mode" in item and self.mode in item["mode"].split(", ")
-            ]
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
 
         num_images = len(items)
 
-        valid_set_size = int(num_images * 0.1)
-        if self.valid_set_size:
-            valid_set_size = min(valid_set_size, self.valid_set_size)
+        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-- 
cgit v1.2.3-70-g09d2