From 8f4d212b3833041448678ad8a44a9a327934f74a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 15 Dec 2022 20:30:59 +0100 Subject: Avoid increased VRAM usage on validation --- data/csv.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 20ac992..053457b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule): dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, + mode: Optional[str] = None, template_key: str = "template", valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, @@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule): self.repeats = repeats self.dropout = dropout self.center_crop = center_crop + self.mode = mode self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule): self.data_root.joinpath(image.format(item["image"])), None, prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), - nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) + nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), ) for item in data ] @@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule): item.instance_image_path, self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), item.prompt, - item.nprompt + item.nprompt, ) for item in items for i in range(image_multiplier) @@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule): template = metadata[self.template_key] if self.template_key in metadata else {} items = metadata["items"] if "items" in metadata else [] - items = [item for item in items if not "skip" in item or item["skip"] != True] + if self.mode is not None: + items = [ + item + for item in items + if "mode" in item and self.mode in item["mode"] + ] items = self.prepare_items(template, items) items = self.filter_items(items) -- cgit v1.2.3-70-g09d2