From 3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 23:02:01 +0100 Subject: Better dataset prompt handling --- data/csv.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index edce2b1..265293b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -15,10 +15,11 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt -def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str: +def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: if dropout != 0: prompt = [keyword for keyword in prompt if np.random.random() > dropout] - np.random.shuffle(prompt) + if shuffle: + np.random.shuffle(prompt) return ", ".join(prompt) @@ -38,8 +39,8 @@ class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] - cprompt: str - nprompt: str + cprompt: list[str] + nprompt: list[str] class CSVDataModule(): @@ -104,8 +105,14 @@ class CSVDataModule(): prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions ), - cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), - nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), + prompt_to_keywords( + cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), + expansions + ), + prompt_to_keywords( + prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), + expansions + ), ) for item in data ] @@ -253,9 +260,9 @@ class CSVDataset(Dataset): example = {} - example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) - example["cprompts"] = unprocessed_example["cprompts"] - example["nprompts"] = unprocessed_example["nprompts"] + example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) + example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) + example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) -- cgit v1.2.3-54-g00ecf