diff options
author | Volpeon <git@volpeon.ink> | 2022-12-19 21:10:58 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-19 21:10:58 +0100 |
commit | 9b808b6ca102cfec0c273626a0bcadf897b7c942 (patch) | |
tree | 446311b3c6dca74ac9f9f4e055e2eba5f9cae9e5 /data | |
parent | Avoid increased VRAM usage on validation (diff) | |
download | textual-inversion-diff-9b808b6ca102cfec0c273626a0bcadf897b7c942.tar.gz textual-inversion-diff-9b808b6ca102cfec0c273626a0bcadf897b7c942.tar.bz2 textual-inversion-diff-9b808b6ca102cfec0c273626a0bcadf897b7c942.zip |
Improved dataset prompt handling, fixed
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 41 |
1 file changed, 23 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py index 053457b..6525e45 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -16,26 +16,29 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): | |||
16 | return {"content": prompt} if isinstance(prompt, str) else prompt | 16 | return {"content": prompt} if isinstance(prompt, str) else prompt |
17 | 17 | ||
18 | 18 | ||
19 | def shuffle_prompt(prompt: str, dropout: float = 0): | 19 | def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str: |
20 | def handle_block(block: str): | 20 | if dropout != 0: |
21 | words = block.split(", ") | 21 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] |
22 | words = [w for w in words if w != ""] | ||
23 | if dropout != 0: | ||
24 | words = [w for w in words if np.random.random() > dropout] | ||
25 | np.random.shuffle(words) | ||
26 | return ", ".join(words) | ||
27 | |||
28 | prompt = prompt.split(". ") | ||
29 | prompt = [handle_block(b) for b in prompt if b != ""] | ||
30 | np.random.shuffle(prompt) | 22 | np.random.shuffle(prompt) |
31 | prompt = ". ".join(prompt) | 23 | return ", ".join(prompt) |
32 | return prompt | 24 | |
25 | |||
26 | def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]: | ||
27 | def expand_keyword(keyword: str) -> list[str]: | ||
28 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | ||
29 | |||
30 | return [ | ||
31 | kw | ||
32 | for keyword in prompt.split(", ") | ||
33 | for kw in expand_keyword(keyword) | ||
34 | if keyword != "" | ||
35 | ] | ||
33 | 36 | ||
34 | 37 | ||
35 | class CSVDataItem(NamedTuple): | 38 | class CSVDataItem(NamedTuple): |
36 | instance_image_path: Path | 39 | instance_image_path: Path |
37 | class_image_path: Path | 40 | class_image_path: Path |
38 | prompt: str | 41 | prompt: list[str] |
39 | nprompt: str | 42 | nprompt: str |
40 | 43 | ||
41 | 44 | ||
@@ -91,7 +94,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
91 | self.num_workers = num_workers | 94 | self.num_workers = num_workers |
92 | self.batch_size = batch_size | 95 | self.batch_size = batch_size |
93 | 96 | ||
94 | def prepare_items(self, template, data) -> list[CSVDataItem]: | 97 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: |
95 | image = template["image"] if "image" in template else "{}" | 98 | image = template["image"] if "image" in template else "{}" |
96 | prompt = template["prompt"] if "prompt" in template else "{content}" | 99 | prompt = template["prompt"] if "prompt" in template else "{content}" |
97 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 100 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
@@ -100,7 +103,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
100 | CSVDataItem( | 103 | CSVDataItem( |
101 | self.data_root.joinpath(image.format(item["image"])), | 104 | self.data_root.joinpath(image.format(item["image"])), |
102 | None, | 105 | None, |
103 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 106 | prompt_to_keywords(prompt.format( |
107 | **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), | ||
104 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 108 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
105 | ) | 109 | ) |
106 | for item in data | 110 | for item in data |
@@ -130,6 +134,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
130 | with open(self.data_file, 'rt') as f: | 134 | with open(self.data_file, 'rt') as f: |
131 | metadata = json.load(f) | 135 | metadata = json.load(f) |
132 | template = metadata[self.template_key] if self.template_key in metadata else {} | 136 | template = metadata[self.template_key] if self.template_key in metadata else {} |
137 | expansions = metadata["expansions"] if "expansions" in metadata else {} | ||
133 | items = metadata["items"] if "items" in metadata else [] | 138 | items = metadata["items"] if "items" in metadata else [] |
134 | 139 | ||
135 | if self.mode is not None: | 140 | if self.mode is not None: |
@@ -138,7 +143,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
138 | for item in items | 143 | for item in items |
139 | if "mode" in item and self.mode in item["mode"] | 144 | if "mode" in item and self.mode in item["mode"] |
140 | ] | 145 | ] |
141 | items = self.prepare_items(template, items) | 146 | items = self.prepare_items(template, expansions, items) |
142 | items = self.filter_items(items) | 147 | items = self.filter_items(items) |
143 | 148 | ||
144 | num_images = len(items) | 149 | num_images = len(items) |
@@ -255,7 +260,7 @@ class CSVDataset(Dataset): | |||
255 | 260 | ||
256 | example = {} | 261 | example = {} |
257 | 262 | ||
258 | example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) | 263 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) |
259 | example["nprompts"] = unprocessed_example["nprompts"] | 264 | example["nprompts"] = unprocessed_example["nprompts"] |
260 | 265 | ||
261 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 266 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |