summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-19 21:10:58 +0100
committerVolpeon <git@volpeon.ink>2022-12-19 21:10:58 +0100
commit9b808b6ca102cfec0c273626a0bcadf897b7c942 (patch)
tree446311b3c6dca74ac9f9f4e055e2eba5f9cae9e5 /data
parentAvoid increased VRAM usage on validation (diff)
downloadtextual-inversion-diff-9b808b6ca102cfec0c273626a0bcadf897b7c942.tar.gz
textual-inversion-diff-9b808b6ca102cfec0c273626a0bcadf897b7c942.tar.bz2
textual-inversion-diff-9b808b6ca102cfec0c273626a0bcadf897b7c942.zip
Improved dataset prompt handling, fixed
Diffstat (limited to 'data')
-rw-r--r--data/csv.py41
1 files changed, 23 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py
index 053457b..6525e45 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -16,26 +16,29 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
16 return {"content": prompt} if isinstance(prompt, str) else prompt 16 return {"content": prompt} if isinstance(prompt, str) else prompt
17 17
18 18
19def shuffle_prompt(prompt: str, dropout: float = 0): 19def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str:
20 def handle_block(block: str): 20 if dropout != 0:
21 words = block.split(", ") 21 prompt = [keyword for keyword in prompt if np.random.random() > dropout]
22 words = [w for w in words if w != ""]
23 if dropout != 0:
24 words = [w for w in words if np.random.random() > dropout]
25 np.random.shuffle(words)
26 return ", ".join(words)
27
28 prompt = prompt.split(". ")
29 prompt = [handle_block(b) for b in prompt if b != ""]
30 np.random.shuffle(prompt) 22 np.random.shuffle(prompt)
31 prompt = ". ".join(prompt) 23 return ", ".join(prompt)
32 return prompt 24
25
26def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
27 def expand_keyword(keyword: str) -> list[str]:
28 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
29
30 return [
31 kw
32 for keyword in prompt.split(", ")
33 for kw in expand_keyword(keyword)
34 if keyword != ""
35 ]
33 36
34 37
35class CSVDataItem(NamedTuple): 38class CSVDataItem(NamedTuple):
36 instance_image_path: Path 39 instance_image_path: Path
37 class_image_path: Path 40 class_image_path: Path
38 prompt: str 41 prompt: list[str]
39 nprompt: str 42 nprompt: str
40 43
41 44
@@ -91,7 +94,7 @@ class CSVDataModule(pl.LightningDataModule):
91 self.num_workers = num_workers 94 self.num_workers = num_workers
92 self.batch_size = batch_size 95 self.batch_size = batch_size
93 96
94 def prepare_items(self, template, data) -> list[CSVDataItem]: 97 def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
95 image = template["image"] if "image" in template else "{}" 98 image = template["image"] if "image" in template else "{}"
96 prompt = template["prompt"] if "prompt" in template else "{content}" 99 prompt = template["prompt"] if "prompt" in template else "{content}"
97 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 100 nprompt = template["nprompt"] if "nprompt" in template else "{content}"
@@ -100,7 +103,8 @@ class CSVDataModule(pl.LightningDataModule):
100 CSVDataItem( 103 CSVDataItem(
101 self.data_root.joinpath(image.format(item["image"])), 104 self.data_root.joinpath(image.format(item["image"])),
102 None, 105 None,
103 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 106 prompt_to_keywords(prompt.format(
107 **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions),
104 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 108 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
105 ) 109 )
106 for item in data 110 for item in data
@@ -130,6 +134,7 @@ class CSVDataModule(pl.LightningDataModule):
130 with open(self.data_file, 'rt') as f: 134 with open(self.data_file, 'rt') as f:
131 metadata = json.load(f) 135 metadata = json.load(f)
132 template = metadata[self.template_key] if self.template_key in metadata else {} 136 template = metadata[self.template_key] if self.template_key in metadata else {}
137 expansions = metadata["expansions"] if "expansions" in metadata else {}
133 items = metadata["items"] if "items" in metadata else [] 138 items = metadata["items"] if "items" in metadata else []
134 139
135 if self.mode is not None: 140 if self.mode is not None:
@@ -138,7 +143,7 @@ class CSVDataModule(pl.LightningDataModule):
138 for item in items 143 for item in items
139 if "mode" in item and self.mode in item["mode"] 144 if "mode" in item and self.mode in item["mode"]
140 ] 145 ]
141 items = self.prepare_items(template, items) 146 items = self.prepare_items(template, expansions, items)
142 items = self.filter_items(items) 147 items = self.filter_items(items)
143 148
144 num_images = len(items) 149 num_images = len(items)
@@ -255,7 +260,7 @@ class CSVDataset(Dataset):
255 260
256 example = {} 261 example = {}
257 262
258 example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) 263 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
259 example["nprompts"] = unprocessed_example["nprompts"] 264 example["nprompts"] = unprocessed_example["nprompts"]
260 265
261 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 266 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])