From 3396ca881ed3f3521617cd9024eea56975191d32 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 13:26:32 +0100 Subject: Update --- data/keywords.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 data/keywords.py (limited to 'data/keywords.py') diff --git a/data/keywords.py b/data/keywords.py new file mode 100644 index 0000000..9e656f3 --- /dev/null +++ b/data/keywords.py @@ -0,0 +1,21 @@ +import numpy as np + + +def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: + if dropout != 0: + prompt = [keyword for keyword in prompt if np.random.random() > dropout] + if shuffle: + np.random.shuffle(prompt) + return ", ".join(prompt) + + +def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: + def expand_keyword(keyword: str) -> list[str]: + return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] + + return [ + kw + for keyword in prompt.split(", ") + for kw in expand_keyword(keyword) + if keyword != "" + ] -- cgit v1.2.3-54-g00ecf