summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2022-12-23 23:02:01 +0100
committer Volpeon <git@volpeon.ink> 2022-12-23 23:02:01 +0100
commit 3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de (patch)
tree 7b12a26c195e7298bb6cbc993ad0dd0f322fede4 /data
parent num_class_images is now class images per train image (diff)
download textual-inversion-diff-3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de.tar.gz
textual-inversion-diff-3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de.tar.bz2
textual-inversion-diff-3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de.zip
Better dataset prompt handling
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 25
1 files changed, 16 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py
index edce2b1..265293b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,10 +15,11 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
15 return {"content": prompt} if isinstance(prompt, str) else prompt 15 return {"content": prompt} if isinstance(prompt, str) else prompt
16 16
17 17
18def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str: 18def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
19 if dropout != 0: 19 if dropout != 0:
20 prompt = [keyword for keyword in prompt if np.random.random() > dropout] 20 prompt = [keyword for keyword in prompt if np.random.random() > dropout]
21 np.random.shuffle(prompt) 21 if shuffle:
22 np.random.shuffle(prompt)
22 return ", ".join(prompt) 23 return ", ".join(prompt)
23 24
24 25
@@ -38,8 +39,8 @@ class CSVDataItem(NamedTuple):
38 instance_image_path: Path 39 instance_image_path: Path
39 class_image_path: Path 40 class_image_path: Path
40 prompt: list[str] 41 prompt: list[str]
41 cprompt: str 42 cprompt: list[str]
42 nprompt: str 43 nprompt: list[str]
43 44
44 45
45class CSVDataModule(): 46class CSVDataModule():
@@ -104,8 +105,14 @@ class CSVDataModule():
104 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 105 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
105 expansions 106 expansions
106 ), 107 ),
107 cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), 108 prompt_to_keywords(
108 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 109 cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
110 expansions
111 ),
112 prompt_to_keywords(
113 prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
114 expansions
115 ),
109 ) 116 )
110 for item in data 117 for item in data
111 ] 118 ]
@@ -253,9 +260,9 @@ class CSVDataset(Dataset):
253 260
254 example = {} 261 example = {}
255 262
256 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) 263 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
257 example["cprompts"] = unprocessed_example["cprompts"] 264 example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"])
258 example["nprompts"] = unprocessed_example["nprompts"] 265 example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"])
259 266
260 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 267 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
261 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) 268 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])