From e3669927b47b5367a3348d30c4b318da84af661d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Apr 2023 11:14:03 +0200 Subject: Update dataset format: Separate prompt and keywords --- data/csv.py | 67 +++++++++++++++++++++++--------------------------- data/keywords.py | 12 ++++----- infer.py | 4 +-- train_ti.py | 2 +- training/functional.py | 9 ++++--- 5 files changed, 46 insertions(+), 48 deletions(-) diff --git a/data/csv.py b/data/csv.py index c00ea07..d0ac317 100644 --- a/data/csv.py +++ b/data/csv.py @@ -11,7 +11,7 @@ from torch.utils.data import IterableDataset, DataLoader, random_split from torchvision import transforms from transformers import CLIPTokenizer -from data.keywords import prompt_to_keywords, keywords_to_prompt +from data.keywords import str_to_keywords, keywords_to_str from models.clip.util import unify_input_ids @@ -37,7 +37,7 @@ def get_image(path): return image -def prepare_prompt(prompt: Union[str, dict[str, str]]): +def prepare_tpl_slots(prompt: Union[str, dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt @@ -135,11 +135,18 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path - prompt: list[str] + keywords: list[str] + prompt: str cprompt: str - nprompt: list[str] + nprompt: str collection: list[str] + def full_prompt(self, dropout: float = 0, shuffle: bool = False): + prompt = self.prompt + if len(self.keywords): + prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle) + return prompt + def keyword_filter( placeholder_tokens: Optional[list[str]], @@ -147,10 +154,11 @@ def keyword_filter( exclude_collections: Optional[list[str]], item: VlpnDataItem ): + full_prompt = item.full_prompt() + cond1 = placeholder_tokens is None or any( - keyword in part - for keyword in placeholder_tokens - for part in item.prompt + token in full_prompt + for token in placeholder_tokens ) cond2 = collections is None or any( collection in item.collection @@ -224,6 +232,7 @@ class VlpnDataModule(): def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: tpl_image = template["image"] if "image" in template else "{}" + tpl_keywords = template["keywords"] if "keywords" in template else "{content}" tpl_prompt = template["prompt"] if "prompt" in template else "{content}" tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" @@ -232,37 +241,26 @@ class VlpnDataModule(): for item in data: image = tpl_image.format(item["image"]) - prompt = item["prompt"] if "prompt" in item else "" - nprompt = item["nprompt"] if "nprompt" in item else "" + keywords = prepare_tpl_slots(item["keywords"] if "keywords" in item else "") + prompt = prepare_tpl_slots(item["prompt"] if "prompt" in item else "") + nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") collection = item["collection"].split(", ") if "collection" in item else [] - prompt_keywords = prompt_to_keywords( - tpl_prompt.format(**prepare_prompt(prompt)), - expansions - ) - - cprompt = keywords_to_prompt(prompt_to_keywords( - tpl_cprompt.format(**prepare_prompt(prompt)), - expansions - )) + saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) - inverted_tokens = keywords_to_prompt([ + inverted_tokens = keywords_to_str([ f"inv_{token}" for token in self.placeholder_tokens - if token in prompt_keywords + if token in saturated_keywords ]) - nprompt_keywords = prompt_to_keywords( - tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), - expansions - ) - items.append(VlpnDataItem( self.data_root / image, None, - prompt_keywords, - cprompt, - nprompt_keywords, + saturated_keywords, + tpl_prompt.format(**prompt), + tpl_cprompt.format(**prompt), + tpl_nprompt.format(_inv=inverted_tokens, **nprompt), collection )) @@ -281,6 +279,7 @@ class VlpnDataModule(): VlpnDataItem( item.instance_image_path, self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", + item.keywords, item.prompt, item.cprompt, item.nprompt, @@ -473,15 +472,11 @@ class VlpnDataset(IterableDataset): example = {} - example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) + example["prompt_ids"] = self.get_input_ids(item.full_prompt()) + example["nprompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_prompt_ids"] = self.get_input_ids( - keywords_to_prompt(item.prompt, self.dropout, True) - ) - example["negative_prompt_ids"] = self.get_input_ids( - keywords_to_prompt(item.nprompt, self.dropout, True) - ) + example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) + example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: diff --git a/data/keywords.py b/data/keywords.py index 9e656f3..7385809 100644 --- a/data/keywords.py +++ b/data/keywords.py @@ -1,21 +1,21 @@ import numpy as np -def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: +def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: if dropout != 0: - prompt = [keyword for keyword in prompt if np.random.random() > dropout] + keywords = [keyword for keyword in keywords if np.random.random() > dropout] if shuffle: - np.random.shuffle(prompt) - return ", ".join(prompt) + np.random.shuffle(keywords) + return ", ".join(keywords) -def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: +def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: def expand_keyword(keyword: str) -> list[str]: return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] return [ kw - for keyword in prompt.split(", ") + for keyword in s.split(", ") for kw in expand_keyword(keyword) if keyword != "" ] diff --git a/infer.py b/infer.py index cf59bba..ed86ab1 100644 --- a/infer.py +++ b/infer.py @@ -28,7 +28,7 @@ from diffusers import ( ) from transformers import CLIPTextModel -from data.keywords import prompt_to_keywords, keywords_to_prompt +from data.keywords import str_to_keywords, keywords_to_str from models.clip.embeddings import patch_managed_embeddings from models.clip.tokenizer import MultiCLIPTokenizer from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion @@ -296,7 +296,7 @@ def create_pipeline(model, dtype): def shuffle_prompts(prompts: list[str]) -> list[str]: - return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] + return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts] @torch.inference_mode() diff --git a/train_ti.py b/train_ti.py index 5482326..651dfbe 100644 --- a/train_ti.py +++ b/train_ti.py @@ -186,7 +186,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0, + default=0.1, help="Tag dropout probability.", ) parser.add_argument( diff --git a/training/functional.py b/training/functional.py index b9fb546..96ecbc1 100644 --- a/training/functional.py +++ b/training/functional.py @@ -522,9 +522,12 @@ def train_loop( accelerator.wait_for_everyone() - lr = lr_scheduler.get_last_lr()[0] - if torch.is_tensor(lr): - lr = lr.item() + if isDadaptation: + lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + else: + lr = lr_scheduler.get_last_lr()[0] + if torch.is_tensor(lr): + lr = lr.item() lrs.append(lr) -- cgit v1.2.3-54-g00ecf