From 9b808b6ca102cfec0c273626a0bcadf897b7c942 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 19 Dec 2022 21:10:58 +0100 Subject: Improved dataset prompt handling, fixed --- data/csv.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 053457b..6525e45 100644 --- a/data/csv.py +++ b/data/csv.py @@ -16,26 +16,29 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt -def shuffle_prompt(prompt: str, dropout: float = 0): - def handle_block(block: str): - words = block.split(", ") - words = [w for w in words if w != ""] - if dropout != 0: - words = [w for w in words if np.random.random() > dropout] - np.random.shuffle(words) - return ", ".join(words) - - prompt = prompt.split(". ") - prompt = [handle_block(b) for b in prompt if b != ""] +def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str: + if dropout != 0: + prompt = [keyword for keyword in prompt if np.random.random() > dropout] np.random.shuffle(prompt) - prompt = ". ".join(prompt) - return prompt + return ", ".join(prompt) + + +def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]: + def expand_keyword(keyword: str) -> list[str]: + return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] + + return [ + kw + for keyword in prompt.split(", ") + for kw in expand_keyword(keyword) + if keyword != "" + ] class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path - prompt: str + prompt: list[str] nprompt: str @@ -91,7 +94,7 @@ class CSVDataModule(pl.LightningDataModule): self.num_workers = num_workers self.batch_size = batch_size - def prepare_items(self, template, data) -> list[CSVDataItem]: + def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: image = template["image"] if "image" in template else "{}" prompt = template["prompt"] if "prompt" in template else "{content}" nprompt = template["nprompt"] if "nprompt" in template else "{content}" @@ -100,7 +103,8 @@ class CSVDataModule(pl.LightningDataModule): CSVDataItem( self.data_root.joinpath(image.format(item["image"])), None, - prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), + prompt_to_keywords(prompt.format( + **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), ) for item in data @@ -130,6 +134,7 @@ class CSVDataModule(pl.LightningDataModule): with open(self.data_file, 'rt') as f: metadata = json.load(f) template = metadata[self.template_key] if self.template_key in metadata else {} + expansions = metadata["expansions"] if "expansions" in metadata else {} items = metadata["items"] if "items" in metadata else [] if self.mode is not None: @@ -138,7 +143,7 @@ class CSVDataModule(pl.LightningDataModule): for item in items if "mode" in item and self.mode in item["mode"] ] - items = self.prepare_items(template, items) + items = self.prepare_items(template, expansions, items) items = self.filter_items(items) num_images = len(items) @@ -255,7 +260,7 @@ class CSVDataset(Dataset): example = {} - example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) + example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) -- cgit v1.2.3-70-g09d2