From dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 30 Dec 2022 13:48:26 +0100 Subject: Training script improvements --- data/csv.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 0ad36dc..4da5d64 100644 --- a/data/csv.py +++ b/data/csv.py @@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple): prompt: list[str] cprompt: str nprompt: str + mode: list[str] class CSVDataModule(): @@ -56,7 +57,6 @@ class CSVDataModule(): dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, - mode: Optional[str] = None, template_key: str = "template", valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, @@ -81,7 +81,6 @@ class CSVDataModule(): self.repeats = repeats self.dropout = dropout self.center_crop = center_crop - self.mode = mode self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -113,6 +112,7 @@ class CSVDataModule(): nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions )), + item["mode"].split(", ") if "mode" in item else [] ) for item in data ] @@ -133,6 +133,7 @@ class CSVDataModule(): item.prompt, item.cprompt, item.nprompt, + item.mode, ) for item in items for i in range(image_multiplier) @@ -145,20 +146,12 @@ class CSVDataModule(): expansions = metadata["expansions"] if "expansions" in metadata else {} items = metadata["items"] if "items" in metadata else [] - if self.mode is not None: - items = [ - item - for item in items - if "mode" in item and self.mode in item["mode"].split(", ") - ] items = self.prepare_items(template, expansions, items) items = self.filter_items(items) num_images = len(items) - valid_set_size = int(num_images * 0.1) - if self.valid_set_size: - valid_set_size = min(valid_set_size, self.valid_set_size) + valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1) valid_set_size = max(valid_set_size, 1) train_set_size = num_images - valid_set_size -- cgit v1.2.3-70-g09d2