diff options
Diffstat (limited to 'data/keywords.py')
-rw-r--r-- | data/keywords.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/data/keywords.py b/data/keywords.py index 9e656f3..7385809 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
@@ -1,21 +1,21 @@ | |||
1 | import numpy as np | 1 | import numpy as np |
2 | 2 | ||
3 | 3 | ||
4 | def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: | 4 | def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: |
5 | if dropout != 0: | 5 | if dropout != 0: |
6 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] | 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
7 | if shuffle: | 7 | if shuffle: |
8 | np.random.shuffle(prompt) | 8 | np.random.shuffle(keywords) |
9 | return ", ".join(prompt) | 9 | return ", ".join(keywords) |
10 | 10 | ||
11 | 11 | ||
12 | def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: | 12 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
13 | def expand_keyword(keyword: str) -> list[str]: | 13 | def expand_keyword(keyword: str) -> list[str]: |
14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] |
15 | 15 | ||
16 | return [ | 16 | return [ |
17 | kw | 17 | kw |
18 | for keyword in prompt.split(", ") | 18 | for keyword in s.split(", ") |
19 | for kw in expand_keyword(keyword) | 19 | for kw in expand_keyword(keyword) |
20 | if keyword != "" | 20 | if keyword != "" |
21 | ] | 21 | ] |