diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 15 |
1 files changed, 4 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py index 0ad36dc..4da5d64 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple): | |||
41 | prompt: list[str] | 41 | prompt: list[str] |
42 | cprompt: str | 42 | cprompt: str |
43 | nprompt: str | 43 | nprompt: str |
44 | mode: list[str] | ||
44 | 45 | ||
45 | 46 | ||
46 | class CSVDataModule(): | 47 | class CSVDataModule(): |
@@ -56,7 +57,6 @@ class CSVDataModule(): | |||
56 | dropout: float = 0, | 57 | dropout: float = 0, |
57 | interpolation: str = "bicubic", | 58 | interpolation: str = "bicubic", |
58 | center_crop: bool = False, | 59 | center_crop: bool = False, |
59 | mode: Optional[str] = None, | ||
60 | template_key: str = "template", | 60 | template_key: str = "template", |
61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
62 | generator: Optional[torch.Generator] = None, | 62 | generator: Optional[torch.Generator] = None, |
@@ -81,7 +81,6 @@ class CSVDataModule(): | |||
81 | self.repeats = repeats | 81 | self.repeats = repeats |
82 | self.dropout = dropout | 82 | self.dropout = dropout |
83 | self.center_crop = center_crop | 83 | self.center_crop = center_crop |
84 | self.mode = mode | ||
85 | self.template_key = template_key | 84 | self.template_key = template_key |
86 | self.interpolation = interpolation | 85 | self.interpolation = interpolation |
87 | self.valid_set_size = valid_set_size | 86 | self.valid_set_size = valid_set_size |
@@ -113,6 +112,7 @@ class CSVDataModule(): | |||
113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
114 | expansions | 113 | expansions |
115 | )), | 114 | )), |
115 | item["mode"].split(", ") if "mode" in item else [] | ||
116 | ) | 116 | ) |
117 | for item in data | 117 | for item in data |
118 | ] | 118 | ] |
@@ -133,6 +133,7 @@ class CSVDataModule(): | |||
133 | item.prompt, | 133 | item.prompt, |
134 | item.cprompt, | 134 | item.cprompt, |
135 | item.nprompt, | 135 | item.nprompt, |
136 | item.mode, | ||
136 | ) | 137 | ) |
137 | for item in items | 138 | for item in items |
138 | for i in range(image_multiplier) | 139 | for i in range(image_multiplier) |
@@ -145,20 +146,12 @@ class CSVDataModule(): | |||
145 | expansions = metadata["expansions"] if "expansions" in metadata else {} | 146 | expansions = metadata["expansions"] if "expansions" in metadata else {} |
146 | items = metadata["items"] if "items" in metadata else [] | 147 | items = metadata["items"] if "items" in metadata else [] |
147 | 148 | ||
148 | if self.mode is not None: | ||
149 | items = [ | ||
150 | item | ||
151 | for item in items | ||
152 | if "mode" in item and self.mode in item["mode"].split(", ") | ||
153 | ] | ||
154 | items = self.prepare_items(template, expansions, items) | 149 | items = self.prepare_items(template, expansions, items) |
155 | items = self.filter_items(items) | 150 | items = self.filter_items(items) |
156 | 151 | ||
157 | num_images = len(items) | 152 | num_images = len(items) |
158 | 153 | ||
159 | valid_set_size = int(num_images * 0.1) | 154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1) |
160 | if self.valid_set_size: | ||
161 | valid_set_size = min(valid_set_size, self.valid_set_size) | ||
162 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
163 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
164 | 157 | ||