diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 41 |
1 files changed, 23 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py index 053457b..6525e45 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -16,26 +16,29 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): | |||
| 16 | return {"content": prompt} if isinstance(prompt, str) else prompt | 16 | return {"content": prompt} if isinstance(prompt, str) else prompt |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | def shuffle_prompt(prompt: str, dropout: float = 0): | 19 | def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str: |
| 20 | def handle_block(block: str): | 20 | if dropout != 0: |
| 21 | words = block.split(", ") | 21 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] |
| 22 | words = [w for w in words if w != ""] | ||
| 23 | if dropout != 0: | ||
| 24 | words = [w for w in words if np.random.random() > dropout] | ||
| 25 | np.random.shuffle(words) | ||
| 26 | return ", ".join(words) | ||
| 27 | |||
| 28 | prompt = prompt.split(". ") | ||
| 29 | prompt = [handle_block(b) for b in prompt if b != ""] | ||
| 30 | np.random.shuffle(prompt) | 22 | np.random.shuffle(prompt) |
| 31 | prompt = ". ".join(prompt) | 23 | return ", ".join(prompt) |
| 32 | return prompt | 24 | |
| 25 | |||
| 26 | def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]: | ||
| 27 | def expand_keyword(keyword: str) -> list[str]: | ||
| 28 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | ||
| 29 | |||
| 30 | return [ | ||
| 31 | kw | ||
| 32 | for keyword in prompt.split(", ") | ||
| 33 | for kw in expand_keyword(keyword) | ||
| 34 | if keyword != "" | ||
| 35 | ] | ||
| 33 | 36 | ||
| 34 | 37 | ||
| 35 | class CSVDataItem(NamedTuple): | 38 | class CSVDataItem(NamedTuple): |
| 36 | instance_image_path: Path | 39 | instance_image_path: Path |
| 37 | class_image_path: Path | 40 | class_image_path: Path |
| 38 | prompt: str | 41 | prompt: list[str] |
| 39 | nprompt: str | 42 | nprompt: str |
| 40 | 43 | ||
| 41 | 44 | ||
| @@ -91,7 +94,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 91 | self.num_workers = num_workers | 94 | self.num_workers = num_workers |
| 92 | self.batch_size = batch_size | 95 | self.batch_size = batch_size |
| 93 | 96 | ||
| 94 | def prepare_items(self, template, data) -> list[CSVDataItem]: | 97 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: |
| 95 | image = template["image"] if "image" in template else "{}" | 98 | image = template["image"] if "image" in template else "{}" |
| 96 | prompt = template["prompt"] if "prompt" in template else "{content}" | 99 | prompt = template["prompt"] if "prompt" in template else "{content}" |
| 97 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 100 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
| @@ -100,7 +103,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 100 | CSVDataItem( | 103 | CSVDataItem( |
| 101 | self.data_root.joinpath(image.format(item["image"])), | 104 | self.data_root.joinpath(image.format(item["image"])), |
| 102 | None, | 105 | None, |
| 103 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 106 | prompt_to_keywords(prompt.format( |
| 107 | **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), | ||
| 104 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 108 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 105 | ) | 109 | ) |
| 106 | for item in data | 110 | for item in data |
| @@ -130,6 +134,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 130 | with open(self.data_file, 'rt') as f: | 134 | with open(self.data_file, 'rt') as f: |
| 131 | metadata = json.load(f) | 135 | metadata = json.load(f) |
| 132 | template = metadata[self.template_key] if self.template_key in metadata else {} | 136 | template = metadata[self.template_key] if self.template_key in metadata else {} |
| 137 | expansions = metadata["expansions"] if "expansions" in metadata else {} | ||
| 133 | items = metadata["items"] if "items" in metadata else [] | 138 | items = metadata["items"] if "items" in metadata else [] |
| 134 | 139 | ||
| 135 | if self.mode is not None: | 140 | if self.mode is not None: |
| @@ -138,7 +143,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 138 | for item in items | 143 | for item in items |
| 139 | if "mode" in item and self.mode in item["mode"] | 144 | if "mode" in item and self.mode in item["mode"] |
| 140 | ] | 145 | ] |
| 141 | items = self.prepare_items(template, items) | 146 | items = self.prepare_items(template, expansions, items) |
| 142 | items = self.filter_items(items) | 147 | items = self.filter_items(items) |
| 143 | 148 | ||
| 144 | num_images = len(items) | 149 | num_images = len(items) |
| @@ -255,7 +260,7 @@ class CSVDataset(Dataset): | |||
| 255 | 260 | ||
| 256 | example = {} | 261 | example = {} |
| 257 | 262 | ||
| 258 | example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) | 263 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) |
| 259 | example["nprompts"] = unprocessed_example["nprompts"] | 264 | example["nprompts"] = unprocessed_example["nprompts"] |
| 260 | 265 | ||
| 261 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 266 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
