From e3669927b47b5367a3348d30c4b318da84af661d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Apr 2023 11:14:03 +0200 Subject: Update dataset format: Separate prompt and keywords --- data/csv.py | 67 ++++++++++++++++++++++++++------------------------------ data/keywords.py | 12 +++++----- 2 files changed, 37 insertions(+), 42 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index c00ea07..d0ac317 100644 --- a/data/csv.py +++ b/data/csv.py @@ -11,7 +11,7 @@ from torch.utils.data import IterableDataset, DataLoader, random_split from torchvision import transforms from transformers import CLIPTokenizer -from data.keywords import prompt_to_keywords, keywords_to_prompt +from data.keywords import str_to_keywords, keywords_to_str from models.clip.util import unify_input_ids @@ -37,7 +37,7 @@ def get_image(path): return image -def prepare_prompt(prompt: Union[str, dict[str, str]]): +def prepare_tpl_slots(prompt: Union[str, dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt @@ -135,11 +135,18 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path - prompt: list[str] + keywords: list[str] + prompt: str cprompt: str - nprompt: list[str] + nprompt: str collection: list[str] + def full_prompt(self, dropout: float = 0, shuffle: bool = False): + prompt = self.prompt + if len(self.keywords): + prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle) + return prompt + def keyword_filter( placeholder_tokens: Optional[list[str]], @@ -147,10 +154,11 @@ def keyword_filter( exclude_collections: Optional[list[str]], item: VlpnDataItem ): + full_prompt = item.full_prompt() + cond1 = placeholder_tokens is None or any( - keyword in part - for keyword in placeholder_tokens - for part in item.prompt + token in full_prompt + for token in placeholder_tokens ) cond2 = collections is None or any( collection in item.collection @@ -224,6 +232,7 @@ class VlpnDataModule(): def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: tpl_image = template["image"] if "image" in template else "{}" + tpl_keywords = template["keywords"] if "keywords" in template else "{content}" tpl_prompt = template["prompt"] if "prompt" in template else "{content}" tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" @@ -232,37 +241,26 @@ class VlpnDataModule(): for item in data: image = tpl_image.format(item["image"]) - prompt = item["prompt"] if "prompt" in item else "" - nprompt = item["nprompt"] if "nprompt" in item else "" + keywords = prepare_tpl_slots(item["keywords"] if "keywords" in item else "") + prompt = prepare_tpl_slots(item["prompt"] if "prompt" in item else "") + nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") collection = item["collection"].split(", ") if "collection" in item else [] - prompt_keywords = prompt_to_keywords( - tpl_prompt.format(**prepare_prompt(prompt)), - expansions - ) - - cprompt = keywords_to_prompt(prompt_to_keywords( - tpl_cprompt.format(**prepare_prompt(prompt)), - expansions - )) + saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) - inverted_tokens = keywords_to_prompt([ + inverted_tokens = keywords_to_str([ f"inv_{token}" for token in self.placeholder_tokens - if token in prompt_keywords + if token in saturated_keywords ]) - nprompt_keywords = prompt_to_keywords( - tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), - expansions - ) - items.append(VlpnDataItem( self.data_root / image, None, - prompt_keywords, - cprompt, - nprompt_keywords, + saturated_keywords, + tpl_prompt.format(**prompt), + tpl_cprompt.format(**prompt), + tpl_nprompt.format(_inv=inverted_tokens, **nprompt), collection )) @@ -281,6 +279,7 @@ class VlpnDataModule(): VlpnDataItem( item.instance_image_path, self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", + item.keywords, item.prompt, item.cprompt, item.nprompt, @@ -473,15 +472,11 @@ class VlpnDataset(IterableDataset): example = {} - example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) + example["prompt_ids"] = self.get_input_ids(item.full_prompt()) + example["nprompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_prompt_ids"] = self.get_input_ids( - keywords_to_prompt(item.prompt, self.dropout, True) - ) - example["negative_prompt_ids"] = self.get_input_ids( - keywords_to_prompt(item.nprompt, self.dropout, True) - ) + example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) + example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: diff --git a/data/keywords.py b/data/keywords.py index 9e656f3..7385809 100644 --- a/data/keywords.py +++ b/data/keywords.py @@ -1,21 +1,21 @@ import numpy as np -def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: +def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: if dropout != 0: - prompt = [keyword for keyword in prompt if np.random.random() > dropout] + keywords = [keyword for keyword in keywords if np.random.random() > dropout] if shuffle: - np.random.shuffle(prompt) - return ", ".join(prompt) + np.random.shuffle(keywords) + return ", ".join(keywords) -def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: +def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: def expand_keyword(keyword: str) -> list[str]: return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] return [ kw - for keyword in prompt.split(", ") + for keyword in s.split(", ") for kw in expand_keyword(keyword) if keyword != "" ] -- cgit v1.2.3-70-g09d2