blob: 629006d2362248b90d910e5f18b72b75ac42d67f (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
from typing import Optional

import numpy as np
def keywords_to_str(
    keywords: list[str],
    undroppable_keywords: Optional[list[str]] = None,
    dropout: float = 0,
    shuffle: bool = False,
) -> str:
    """Join keywords into one ", "-separated string, with optional dropout and shuffling.

    Args:
        keywords: Keywords that may be randomly dropped. Any iterable of
            strings is accepted; the caller's object is never modified.
        undroppable_keywords: Keywords that are always kept. They are appended
            after the surviving ``keywords`` (before any shuffle). ``None``
            means there are none.
        dropout: Per-keyword drop probability for ``keywords``. A keyword is
            kept when ``np.random.random() > dropout``, so ``0`` keeps all of
            them and ``1`` drops all of them.
        shuffle: If true, shuffle the combined list in place with
            ``np.random.shuffle`` before joining.

    Returns:
        The kept keywords joined with ", ".
    """
    if dropout != 0:
        # One draw from NumPy's global RNG per keyword, in input order, so the
        # result is reproducible under np.random.seed.
        kept = [keyword for keyword in keywords if np.random.random() > dropout]
    else:
        # No RNG draws here; list() makes a fresh copy so the caller's
        # sequence is never mutated by the extend/shuffle below.
        kept = list(keywords)
    if undroppable_keywords is not None:
        kept.extend(undroppable_keywords)
    if shuffle:
        np.random.shuffle(kept)
    return ", ".join(kept)
def str_to_keywords(s: str, expansions: Optional[dict[str, str]] = None) -> list[str]:
    """Split a ", "-separated string into keywords, inserting known expansions.

    Args:
        s: Keywords separated by exactly ", " (the separator keywords_to_str
            joins with). Other separators such as a bare "," are not split on.
        expansions: Optional mapping from a keyword to a ", "-separated string
            of extra keywords inserted immediately after it. Expansion is
            applied once and is not recursive. ``None`` means no expansions.

    Returns:
        The keywords in input order, each followed by its expansions. Empty
        entries are dropped. These come from an empty ``s``, doubled or
        trailing separators, or empty/trailing-separator expansion strings.
    """
    if expansions is None:
        expansions = {}

    def expand_keyword(keyword: str) -> list[str]:
        if keyword not in expansions:
            return [keyword]
        # Filter here too: "".split(", ") == [""], and "x, ".split(", ")
        # ends with "", neither of which is a real keyword.
        return [keyword] + [kw for kw in expansions[keyword].split(", ") if kw != ""]

    return [
        kw
        for keyword in s.split(", ")
        if keyword != ""
        for kw in expand_keyword(keyword)
    ]
|