diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 67 |
1 file changed, 44 insertions, 23 deletions
diff --git a/data/csv.py b/data/csv.py index d52d251..9770bec 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -178,6 +178,7 @@ class VlpnDataModule(): | |||
178 | shuffle: bool = False, | 178 | shuffle: bool = False, |
179 | interpolation: str = "bicubic", | 179 | interpolation: str = "bicubic", |
180 | template_key: str = "template", | 180 | template_key: str = "template", |
181 | placeholder_tokens: list[str] = [], | ||
181 | valid_set_size: Optional[int] = None, | 182 | valid_set_size: Optional[int] = None, |
182 | train_set_pad: Optional[int] = None, | 183 | train_set_pad: Optional[int] = None, |
183 | valid_set_pad: Optional[int] = None, | 184 | valid_set_pad: Optional[int] = None, |
@@ -195,6 +196,7 @@ class VlpnDataModule(): | |||
195 | self.data_root = self.data_file.parent | 196 | self.data_root = self.data_file.parent |
196 | self.class_root = self.data_root / class_subdir | 197 | self.class_root = self.data_root / class_subdir |
197 | self.class_root.mkdir(parents=True, exist_ok=True) | 198 | self.class_root.mkdir(parents=True, exist_ok=True) |
199 | self.placeholder_tokens = placeholder_tokens | ||
198 | self.num_class_images = num_class_images | 200 | self.num_class_images = num_class_images |
199 | self.with_guidance = with_guidance | 201 | self.with_guidance = with_guidance |
200 | 202 | ||
@@ -217,31 +219,50 @@ class VlpnDataModule(): | |||
217 | self.dtype = dtype | 219 | self.dtype = dtype |
218 | 220 | ||
219 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 221 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
220 | image = template["image"] if "image" in template else "{}" | 222 | tpl_image = template["image"] if "image" in template else "{}" |
221 | prompt = template["prompt"] if "prompt" in template else "{content}" | 223 | tpl_prompt = template["prompt"] if "prompt" in template else "{content}" |
222 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 224 | tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" |
223 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 225 | tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
224 | 226 | ||
225 | return [ | 227 | items = [] |
226 | VlpnDataItem( | 228 | |
227 | self.data_root / image.format(item["image"]), | 229 | for item in data: |
228 | None, | 230 | image = tpl_image.format(item["image"]) |
229 | prompt_to_keywords( | 231 | prompt = item["prompt"] if "prompt" in item else "" |
230 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 232 | nprompt = item["nprompt"] if "nprompt" in item else "" |
231 | expansions | 233 | collection = item["collection"].split(", ") if "collection" in item else [] |
232 | ), | 234 | |
233 | keywords_to_prompt(prompt_to_keywords( | 235 | prompt_keywords = prompt_to_keywords( |
234 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 236 | tpl_prompt.format(**prepare_prompt(prompt)), |
235 | expansions | 237 | expansions |
236 | )), | ||
237 | prompt_to_keywords( | ||
238 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | ||
239 | expansions | ||
240 | ), | ||
241 | item["collection"].split(", ") if "collection" in item else [] | ||
242 | ) | 238 | ) |
243 | for item in data | 239 | |
244 | ] | 240 | cprompt = keywords_to_prompt(prompt_to_keywords( |
241 | tpl_cprompt.format(**prepare_prompt(prompt)), | ||
242 | expansions | ||
243 | )) | ||
244 | |||
245 | inverted_tokens = keywords_to_prompt([ | ||
246 | f"inv_{token}" | ||
247 | for token in self.placeholder_tokens | ||
248 | if token in prompt_keywords | ||
249 | ]) | ||
250 | |||
251 | nprompt_keywords = prompt_to_keywords( | ||
252 | tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), | ||
253 | expansions | ||
254 | ) | ||
255 | |||
256 | items.append(VlpnDataItem( | ||
257 | self.data_root / image, | ||
258 | None, | ||
259 | prompt_keywords, | ||
260 | cprompt, | ||
261 | nprompt_keywords, | ||
262 | collection | ||
263 | )) | ||
264 | |||
265 | return items | ||
245 | 266 | ||
246 | def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: | 267 | def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: |
247 | if self.filter is None: | 268 | if self.filter is None: |