From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- data/keywords.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'data/keywords.py') diff --git a/data/keywords.py b/data/keywords.py index 8632d67..83fe9ff 100644 --- a/data/keywords.py +++ b/data/keywords.py @@ -8,7 +8,7 @@ def keywords_to_str( undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False, - npgenerator: Optional[np.random.Generator] = None + npgenerator: Optional[np.random.Generator] = None, ) -> str: if dropout != 0: keywords = [keyword for keyword in keywords if np.random.random() > dropout] @@ -23,7 +23,11 @@ def keywords_to_str( def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: def expand_keyword(keyword: str) -> list[str]: - return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] + return ( + [keyword] + expansions[keyword].split(", ") + if keyword in expansions + else [keyword] + ) return [ kw -- cgit v1.2.3-54-g00ecf