diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 67 | ||||
| -rw-r--r-- | data/keywords.py | 12 |
2 files changed, 37 insertions, 42 deletions
diff --git a/data/csv.py b/data/csv.py index c00ea07..d0ac317 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -11,7 +11,7 @@ from torch.utils.data import IterableDataset, DataLoader, random_split | |||
| 11 | from torchvision import transforms | 11 | from torchvision import transforms |
| 12 | from transformers import CLIPTokenizer | 12 | from transformers import CLIPTokenizer |
| 13 | 13 | ||
| 14 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 14 | from data.keywords import str_to_keywords, keywords_to_str |
| 15 | from models.clip.util import unify_input_ids | 15 | from models.clip.util import unify_input_ids |
| 16 | 16 | ||
| 17 | 17 | ||
| @@ -37,7 +37,7 @@ def get_image(path): | |||
| 37 | return image | 37 | return image |
| 38 | 38 | ||
| 39 | 39 | ||
| 40 | def prepare_prompt(prompt: Union[str, dict[str, str]]): | 40 | def prepare_tpl_slots(prompt: Union[str, dict[str, str]]): |
| 41 | return {"content": prompt} if isinstance(prompt, str) else prompt | 41 | return {"content": prompt} if isinstance(prompt, str) else prompt |
| 42 | 42 | ||
| 43 | 43 | ||
| @@ -135,11 +135,18 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool | |||
| 135 | class VlpnDataItem(NamedTuple): | 135 | class VlpnDataItem(NamedTuple): |
| 136 | instance_image_path: Path | 136 | instance_image_path: Path |
| 137 | class_image_path: Path | 137 | class_image_path: Path |
| 138 | prompt: list[str] | 138 | keywords: list[str] |
| 139 | prompt: str | ||
| 139 | cprompt: str | 140 | cprompt: str |
| 140 | nprompt: list[str] | 141 | nprompt: str |
| 141 | collection: list[str] | 142 | collection: list[str] |
| 142 | 143 | ||
| 144 | def full_prompt(self, dropout: float = 0, shuffle: bool = False): | ||
| 145 | prompt = self.prompt | ||
| 146 | if len(self.keywords): | ||
| 147 | prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle) | ||
| 148 | return prompt | ||
| 149 | |||
| 143 | 150 | ||
| 144 | def keyword_filter( | 151 | def keyword_filter( |
| 145 | placeholder_tokens: Optional[list[str]], | 152 | placeholder_tokens: Optional[list[str]], |
| @@ -147,10 +154,11 @@ def keyword_filter( | |||
| 147 | exclude_collections: Optional[list[str]], | 154 | exclude_collections: Optional[list[str]], |
| 148 | item: VlpnDataItem | 155 | item: VlpnDataItem |
| 149 | ): | 156 | ): |
| 157 | full_prompt = item.full_prompt() | ||
| 158 | |||
| 150 | cond1 = placeholder_tokens is None or any( | 159 | cond1 = placeholder_tokens is None or any( |
| 151 | keyword in part | 160 | token in full_prompt |
| 152 | for keyword in placeholder_tokens | 161 | for token in placeholder_tokens |
| 153 | for part in item.prompt | ||
| 154 | ) | 162 | ) |
| 155 | cond2 = collections is None or any( | 163 | cond2 = collections is None or any( |
| 156 | collection in item.collection | 164 | collection in item.collection |
| @@ -224,6 +232,7 @@ class VlpnDataModule(): | |||
| 224 | 232 | ||
| 225 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 233 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
| 226 | tpl_image = template["image"] if "image" in template else "{}" | 234 | tpl_image = template["image"] if "image" in template else "{}" |
| 235 | tpl_keywords = template["keywords"] if "keywords" in template else "{content}" | ||
| 227 | tpl_prompt = template["prompt"] if "prompt" in template else "{content}" | 236 | tpl_prompt = template["prompt"] if "prompt" in template else "{content}" |
| 228 | tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 237 | tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" |
| 229 | tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 238 | tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
| @@ -232,37 +241,26 @@ class VlpnDataModule(): | |||
| 232 | 241 | ||
| 233 | for item in data: | 242 | for item in data: |
| 234 | image = tpl_image.format(item["image"]) | 243 | image = tpl_image.format(item["image"]) |
| 235 | prompt = item["prompt"] if "prompt" in item else "" | 244 | keywords = prepare_tpl_slots(item["keywords"] if "keywords" in item else "") |
| 236 | nprompt = item["nprompt"] if "nprompt" in item else "" | 245 | prompt = prepare_tpl_slots(item["prompt"] if "prompt" in item else "") |
| 246 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") | ||
| 237 | collection = item["collection"].split(", ") if "collection" in item else [] | 247 | collection = item["collection"].split(", ") if "collection" in item else [] |
| 238 | 248 | ||
| 239 | prompt_keywords = prompt_to_keywords( | 249 | saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) |
| 240 | tpl_prompt.format(**prepare_prompt(prompt)), | ||
| 241 | expansions | ||
| 242 | ) | ||
| 243 | |||
| 244 | cprompt = keywords_to_prompt(prompt_to_keywords( | ||
| 245 | tpl_cprompt.format(**prepare_prompt(prompt)), | ||
| 246 | expansions | ||
| 247 | )) | ||
| 248 | 250 | ||
| 249 | inverted_tokens = keywords_to_prompt([ | 251 | inverted_tokens = keywords_to_str([ |
| 250 | f"inv_{token}" | 252 | f"inv_{token}" |
| 251 | for token in self.placeholder_tokens | 253 | for token in self.placeholder_tokens |
| 252 | if token in prompt_keywords | 254 | if token in saturated_keywords |
| 253 | ]) | 255 | ]) |
| 254 | 256 | ||
| 255 | nprompt_keywords = prompt_to_keywords( | ||
| 256 | tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), | ||
| 257 | expansions | ||
| 258 | ) | ||
| 259 | |||
| 260 | items.append(VlpnDataItem( | 257 | items.append(VlpnDataItem( |
| 261 | self.data_root / image, | 258 | self.data_root / image, |
| 262 | None, | 259 | None, |
| 263 | prompt_keywords, | 260 | saturated_keywords, |
| 264 | cprompt, | 261 | tpl_prompt.format(**prompt), |
| 265 | nprompt_keywords, | 262 | tpl_cprompt.format(**prompt), |
| 263 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), | ||
| 266 | collection | 264 | collection |
| 267 | )) | 265 | )) |
| 268 | 266 | ||
| @@ -281,6 +279,7 @@ class VlpnDataModule(): | |||
| 281 | VlpnDataItem( | 279 | VlpnDataItem( |
| 282 | item.instance_image_path, | 280 | item.instance_image_path, |
| 283 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | 281 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", |
| 282 | item.keywords, | ||
| 284 | item.prompt, | 283 | item.prompt, |
| 285 | item.cprompt, | 284 | item.cprompt, |
| 286 | item.nprompt, | 285 | item.nprompt, |
| @@ -473,15 +472,11 @@ class VlpnDataset(IterableDataset): | |||
| 473 | 472 | ||
| 474 | example = {} | 473 | example = {} |
| 475 | 474 | ||
| 476 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) | 475 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) |
| 477 | example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) | 476 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
| 478 | 477 | ||
| 479 | example["instance_prompt_ids"] = self.get_input_ids( | 478 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) |
| 480 | keywords_to_prompt(item.prompt, self.dropout, True) | 479 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
| 481 | ) | ||
| 482 | example["negative_prompt_ids"] = self.get_input_ids( | ||
| 483 | keywords_to_prompt(item.nprompt, self.dropout, True) | ||
| 484 | ) | ||
| 485 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 480 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
| 486 | 481 | ||
| 487 | if self.num_class_images != 0: | 482 | if self.num_class_images != 0: |
diff --git a/data/keywords.py b/data/keywords.py index 9e656f3..7385809 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
| @@ -1,21 +1,21 @@ | |||
| 1 | import numpy as np | 1 | import numpy as np |
| 2 | 2 | ||
| 3 | 3 | ||
| 4 | def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: | 4 | def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: |
| 5 | if dropout != 0: | 5 | if dropout != 0: |
| 6 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] | 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
| 7 | if shuffle: | 7 | if shuffle: |
| 8 | np.random.shuffle(prompt) | 8 | np.random.shuffle(keywords) |
| 9 | return ", ".join(prompt) | 9 | return ", ".join(keywords) |
| 10 | 10 | ||
| 11 | 11 | ||
| 12 | def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: | 12 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
| 13 | def expand_keyword(keyword: str) -> list[str]: | 13 | def expand_keyword(keyword: str) -> list[str]: |
| 14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] |
| 15 | 15 | ||
| 16 | return [ | 16 | return [ |
| 17 | kw | 17 | kw |
| 18 | for keyword in prompt.split(", ") | 18 | for keyword in s.split(", ") |
| 19 | for kw in expand_keyword(keyword) | 19 | for kw in expand_keyword(keyword) |
| 20 | if keyword != "" | 20 | if keyword != "" |
| 21 | ] | 21 | ] |
