diff options
Diffstat (limited to 'data/keywords.py')
| -rw-r--r-- | data/keywords.py | 8 |
1 file changed, 6 insertions, 2 deletions
diff --git a/data/keywords.py b/data/keywords.py index 8632d67..83fe9ff 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
| @@ -8,7 +8,7 @@ def keywords_to_str( | |||
| 8 | undroppable_keywords: list[str] = [], | 8 | undroppable_keywords: list[str] = [], |
| 9 | dropout: float = 0, | 9 | dropout: float = 0, |
| 10 | shuffle: bool = False, | 10 | shuffle: bool = False, |
| 11 | npgenerator: Optional[np.random.Generator] = None | 11 | npgenerator: Optional[np.random.Generator] = None, |
| 12 | ) -> str: | 12 | ) -> str: |
| 13 | if dropout != 0: | 13 | if dropout != 0: |
| 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
| @@ -23,7 +23,11 @@ def keywords_to_str( | |||
| 23 | 23 | ||
| 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: | 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
| 25 | def expand_keyword(keyword: str) -> list[str]: | 25 | def expand_keyword(keyword: str) -> list[str]: |
| 26 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 26 | return ( |
| 27 | [keyword] + expansions[keyword].split(", ") | ||
| 28 | if keyword in expansions | ||
| 29 | else [keyword] | ||
| 30 | ) | ||
| 27 | 31 | ||
| 28 | return [ | 32 | return [ |
| 29 | kw | 33 | kw |
