diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 67 | ||||
-rw-r--r-- | data/keywords.py | 12 |
2 files changed, 37 insertions, 42 deletions
diff --git a/data/csv.py b/data/csv.py index c00ea07..d0ac317 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -11,7 +11,7 @@ from torch.utils.data import IterableDataset, DataLoader, random_split | |||
11 | from torchvision import transforms | 11 | from torchvision import transforms |
12 | from transformers import CLIPTokenizer | 12 | from transformers import CLIPTokenizer |
13 | 13 | ||
14 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 14 | from data.keywords import str_to_keywords, keywords_to_str |
15 | from models.clip.util import unify_input_ids | 15 | from models.clip.util import unify_input_ids |
16 | 16 | ||
17 | 17 | ||
@@ -37,7 +37,7 @@ def get_image(path): | |||
37 | return image | 37 | return image |
38 | 38 | ||
39 | 39 | ||
40 | def prepare_prompt(prompt: Union[str, dict[str, str]]): | 40 | def prepare_tpl_slots(prompt: Union[str, dict[str, str]]): |
41 | return {"content": prompt} if isinstance(prompt, str) else prompt | 41 | return {"content": prompt} if isinstance(prompt, str) else prompt |
42 | 42 | ||
43 | 43 | ||
@@ -135,11 +135,18 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool | |||
135 | class VlpnDataItem(NamedTuple): | 135 | class VlpnDataItem(NamedTuple): |
136 | instance_image_path: Path | 136 | instance_image_path: Path |
137 | class_image_path: Path | 137 | class_image_path: Path |
138 | prompt: list[str] | 138 | keywords: list[str] |
139 | prompt: str | ||
139 | cprompt: str | 140 | cprompt: str |
140 | nprompt: list[str] | 141 | nprompt: str |
141 | collection: list[str] | 142 | collection: list[str] |
142 | 143 | ||
144 | def full_prompt(self, dropout: float = 0, shuffle: bool = False): | ||
145 | prompt = self.prompt | ||
146 | if len(self.keywords): | ||
147 | prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle) | ||
148 | return prompt | ||
149 | |||
143 | 150 | ||
144 | def keyword_filter( | 151 | def keyword_filter( |
145 | placeholder_tokens: Optional[list[str]], | 152 | placeholder_tokens: Optional[list[str]], |
@@ -147,10 +154,11 @@ def keyword_filter( | |||
147 | exclude_collections: Optional[list[str]], | 154 | exclude_collections: Optional[list[str]], |
148 | item: VlpnDataItem | 155 | item: VlpnDataItem |
149 | ): | 156 | ): |
157 | full_prompt = item.full_prompt() | ||
158 | |||
150 | cond1 = placeholder_tokens is None or any( | 159 | cond1 = placeholder_tokens is None or any( |
151 | keyword in part | 160 | token in full_prompt |
152 | for keyword in placeholder_tokens | 161 | for token in placeholder_tokens |
153 | for part in item.prompt | ||
154 | ) | 162 | ) |
155 | cond2 = collections is None or any( | 163 | cond2 = collections is None or any( |
156 | collection in item.collection | 164 | collection in item.collection |
@@ -224,6 +232,7 @@ class VlpnDataModule(): | |||
224 | 232 | ||
225 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 233 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
226 | tpl_image = template["image"] if "image" in template else "{}" | 234 | tpl_image = template["image"] if "image" in template else "{}" |
235 | tpl_keywords = template["keywords"] if "keywords" in template else "{content}" | ||
227 | tpl_prompt = template["prompt"] if "prompt" in template else "{content}" | 236 | tpl_prompt = template["prompt"] if "prompt" in template else "{content}" |
228 | tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 237 | tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" |
229 | tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 238 | tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
@@ -232,37 +241,26 @@ class VlpnDataModule(): | |||
232 | 241 | ||
233 | for item in data: | 242 | for item in data: |
234 | image = tpl_image.format(item["image"]) | 243 | image = tpl_image.format(item["image"]) |
235 | prompt = item["prompt"] if "prompt" in item else "" | 244 | keywords = prepare_tpl_slots(item["keywords"] if "keywords" in item else "") |
236 | nprompt = item["nprompt"] if "nprompt" in item else "" | 245 | prompt = prepare_tpl_slots(item["prompt"] if "prompt" in item else "") |
246 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") | ||
237 | collection = item["collection"].split(", ") if "collection" in item else [] | 247 | collection = item["collection"].split(", ") if "collection" in item else [] |
238 | 248 | ||
239 | prompt_keywords = prompt_to_keywords( | 249 | saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) |
240 | tpl_prompt.format(**prepare_prompt(prompt)), | ||
241 | expansions | ||
242 | ) | ||
243 | |||
244 | cprompt = keywords_to_prompt(prompt_to_keywords( | ||
245 | tpl_cprompt.format(**prepare_prompt(prompt)), | ||
246 | expansions | ||
247 | )) | ||
248 | 250 | ||
249 | inverted_tokens = keywords_to_prompt([ | 251 | inverted_tokens = keywords_to_str([ |
250 | f"inv_{token}" | 252 | f"inv_{token}" |
251 | for token in self.placeholder_tokens | 253 | for token in self.placeholder_tokens |
252 | if token in prompt_keywords | 254 | if token in saturated_keywords |
253 | ]) | 255 | ]) |
254 | 256 | ||
255 | nprompt_keywords = prompt_to_keywords( | ||
256 | tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), | ||
257 | expansions | ||
258 | ) | ||
259 | |||
260 | items.append(VlpnDataItem( | 257 | items.append(VlpnDataItem( |
261 | self.data_root / image, | 258 | self.data_root / image, |
262 | None, | 259 | None, |
263 | prompt_keywords, | 260 | saturated_keywords, |
264 | cprompt, | 261 | tpl_prompt.format(**prompt), |
265 | nprompt_keywords, | 262 | tpl_cprompt.format(**prompt), |
263 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), | ||
266 | collection | 264 | collection |
267 | )) | 265 | )) |
268 | 266 | ||
@@ -281,6 +279,7 @@ class VlpnDataModule(): | |||
281 | VlpnDataItem( | 279 | VlpnDataItem( |
282 | item.instance_image_path, | 280 | item.instance_image_path, |
283 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | 281 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", |
282 | item.keywords, | ||
284 | item.prompt, | 283 | item.prompt, |
285 | item.cprompt, | 284 | item.cprompt, |
286 | item.nprompt, | 285 | item.nprompt, |
@@ -473,15 +472,11 @@ class VlpnDataset(IterableDataset): | |||
473 | 472 | ||
474 | example = {} | 473 | example = {} |
475 | 474 | ||
476 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) | 475 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) |
477 | example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) | 476 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
478 | 477 | ||
479 | example["instance_prompt_ids"] = self.get_input_ids( | 478 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) |
480 | keywords_to_prompt(item.prompt, self.dropout, True) | 479 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
481 | ) | ||
482 | example["negative_prompt_ids"] = self.get_input_ids( | ||
483 | keywords_to_prompt(item.nprompt, self.dropout, True) | ||
484 | ) | ||
485 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 480 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
486 | 481 | ||
487 | if self.num_class_images != 0: | 482 | if self.num_class_images != 0: |
diff --git a/data/keywords.py b/data/keywords.py index 9e656f3..7385809 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
@@ -1,21 +1,21 @@ | |||
1 | import numpy as np | 1 | import numpy as np |
2 | 2 | ||
3 | 3 | ||
4 | def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: | 4 | def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: |
5 | if dropout != 0: | 5 | if dropout != 0: |
6 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] | 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
7 | if shuffle: | 7 | if shuffle: |
8 | np.random.shuffle(prompt) | 8 | np.random.shuffle(keywords) |
9 | return ", ".join(prompt) | 9 | return ", ".join(keywords) |
10 | 10 | ||
11 | 11 | ||
12 | def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: | 12 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
13 | def expand_keyword(keyword: str) -> list[str]: | 13 | def expand_keyword(keyword: str) -> list[str]: |
14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] |
15 | 15 | ||
16 | return [ | 16 | return [ |
17 | kw | 17 | kw |
18 | for keyword in prompt.split(", ") | 18 | for keyword in s.split(", ") |
19 | for kw in expand_keyword(keyword) | 19 | for kw in expand_keyword(keyword) |
20 | if keyword != "" | 20 | if keyword != "" |
21 | ] | 21 | ] |