From 3396ca881ed3f3521617cd9024eea56975191d32 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 5 Jan 2023 13:26:32 +0100
Subject: Update

---
 data/csv.py      | 22 +---------------------
 data/keywords.py | 21 +++++++++++++++++++++
 2 files changed, 22 insertions(+), 21 deletions(-)
 create mode 100644 data/keywords.py

(limited to 'data')

diff --git a/data/csv.py b/data/csv.py
index a60733a..d1f3054 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,6 @@
 import math
 import torch
 import json
-import numpy as np
 from pathlib import Path
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
@@ -9,32 +8,13 @@ from torchvision import transforms
 from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
 from models.clip.prompt import PromptProcessor
+from data.keywords import prompt_to_keywords, keywords_to_prompt
 
 
 def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
-    if dropout != 0:
-        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
-    if shuffle:
-        np.random.shuffle(prompt)
-    return ", ".join(prompt)
-
-
-def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
-    def expand_keyword(keyword: str) -> list[str]:
-        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
-
-    return [
-        kw
-        for keyword in prompt.split(", ")
-        for kw in expand_keyword(keyword)
-        if keyword != ""
-    ]
-
-
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
diff --git a/data/keywords.py b/data/keywords.py
new file mode 100644
index 0000000..9e656f3
--- /dev/null
+++ b/data/keywords.py
@@ -0,0 +1,21 @@
+import numpy as np
+
+
+def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
+    if dropout != 0:
+        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
+    if shuffle:
+        np.random.shuffle(prompt)
+    return ", ".join(prompt)
+
+
+def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]:
+    def expand_keyword(keyword: str) -> list[str]:
+        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
+
+    return [
+        kw
+        for keyword in prompt.split(", ")
+        for kw in expand_keyword(keyword)
+        if keyword != ""
+    ]
-- 
cgit v1.2.3-70-g09d2