summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-15 20:30:59 +0100
committerVolpeon <git@volpeon.ink>2022-12-15 20:30:59 +0100
commit8f4d212b3833041448678ad8a44a9a327934f74a (patch)
tree667edaef8a771a171db4a5afdae1fe8d427a2593 /data
parentMore generic datset filter (diff)
downloadtextual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.gz
textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.bz2
textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.zip
Avoid increased VRAM usage on validation
Diffstat (limited to 'data')
-rw-r--r--data/csv.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 20ac992..053457b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule):
54 dropout: float = 0, 54 dropout: float = 0,
55 interpolation: str = "bicubic", 55 interpolation: str = "bicubic",
56 center_crop: bool = False, 56 center_crop: bool = False,
57 mode: Optional[str] = None,
57 template_key: str = "template", 58 template_key: str = "template",
58 valid_set_size: Optional[int] = None, 59 valid_set_size: Optional[int] = None,
59 generator: Optional[torch.Generator] = None, 60 generator: Optional[torch.Generator] = None,
@@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule):
80 self.repeats = repeats 81 self.repeats = repeats
81 self.dropout = dropout 82 self.dropout = dropout
82 self.center_crop = center_crop 83 self.center_crop = center_crop
84 self.mode = mode
83 self.template_key = template_key 85 self.template_key = template_key
84 self.interpolation = interpolation 86 self.interpolation = interpolation
85 self.valid_set_size = valid_set_size 87 self.valid_set_size = valid_set_size
@@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule):
99 self.data_root.joinpath(image.format(item["image"])), 101 self.data_root.joinpath(image.format(item["image"])),
100 None, 102 None,
101 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 103 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
102 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) 104 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
103 ) 105 )
104 for item in data 106 for item in data
105 ] 107 ]
@@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule):
118 item.instance_image_path, 120 item.instance_image_path,
119 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), 121 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
120 item.prompt, 122 item.prompt,
121 item.nprompt 123 item.nprompt,
122 ) 124 )
123 for item in items 125 for item in items
124 for i in range(image_multiplier) 126 for i in range(image_multiplier)
@@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule):
130 template = metadata[self.template_key] if self.template_key in metadata else {} 132 template = metadata[self.template_key] if self.template_key in metadata else {}
131 items = metadata["items"] if "items" in metadata else [] 133 items = metadata["items"] if "items" in metadata else []
132 134
133 items = [item for item in items if not "skip" in item or item["skip"] != True] 135 if self.mode is not None:
136 items = [
137 item
138 for item in items
139 if "mode" in item and self.mode in item["mode"]
140 ]
134 items = self.prepare_items(template, items) 141 items = self.prepare_items(template, items)
135 items = self.filter_items(items) 142 items = self.filter_items(items)
136 143