From e3669927b47b5367a3348d30c4b318da84af661d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Apr 2023 11:14:03 +0200 Subject: Update dataset format: Separate prompt and keywords --- data/keywords.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'data/keywords.py') diff --git a/data/keywords.py b/data/keywords.py index 9e656f3..7385809 100644 --- a/data/keywords.py +++ b/data/keywords.py @@ -1,21 +1,21 @@ import numpy as np -def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: +def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: if dropout != 0: - prompt = [keyword for keyword in prompt if np.random.random() > dropout] + keywords = [keyword for keyword in keywords if np.random.random() > dropout] if shuffle: - np.random.shuffle(prompt) - return ", ".join(prompt) + np.random.shuffle(keywords) + return ", ".join(keywords) -def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: +def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: def expand_keyword(keyword: str) -> list[str]: return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] return [ kw - for keyword in prompt.split(", ") + for keyword in s.split(", ") for kw in expand_keyword(keyword) if keyword != "" ] -- cgit v1.2.3-54-g00ecf