diff options
Diffstat (limited to 'data/keywords.py')
-rw-r--r-- | data/keywords.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/data/keywords.py b/data/keywords.py new file mode 100644 index 0000000..9e656f3 --- /dev/null +++ b/data/keywords.py | |||
@@ -0,0 +1,21 @@ | |||
1 | import numpy as np | ||
2 | |||
3 | |||
4 | def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: | ||
5 | if dropout != 0: | ||
6 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] | ||
7 | if shuffle: | ||
8 | np.random.shuffle(prompt) | ||
9 | return ", ".join(prompt) | ||
10 | |||
11 | |||
12 | def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: | ||
13 | def expand_keyword(keyword: str) -> list[str]: | ||
14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | ||
15 | |||
16 | return [ | ||
17 | kw | ||
18 | for keyword in prompt.split(", ") | ||
19 | for kw in expand_keyword(keyword) | ||
20 | if keyword != "" | ||
21 | ] | ||