summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/data/keywords.py b/data/keywords.py
new file mode 100644
index 0000000..9e656f3
--- /dev/null
+++ b/data/keywords.py
@@ -0,0 +1,21 @@
1import numpy as np
2
3
4def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0:
6 prompt = [keyword for keyword in prompt if np.random.random() > dropout]
7 if shuffle:
8 np.random.shuffle(prompt)
9 return ", ".join(prompt)
10
11
12def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]:
13 def expand_keyword(keyword: str) -> list[str]:
14 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
15
16 return [
17 kw
18 for keyword in prompt.split(", ")
19 for kw in expand_keyword(keyword)
20 if keyword != ""
21 ]