summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
blob: 8632d67c1ceeaa176b17e35f9b4ff6635bcb83b3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Optional

import numpy as np


def keywords_to_str(
    keywords: list[str],
    undroppable_keywords: Optional[list[str]] = None,
    dropout: float = 0,
    shuffle: bool = False,
    npgenerator: Optional[np.random.Generator] = None
) -> str:
    """Join keywords into a single ``", "``-separated string.

    Args:
        keywords: Keywords that are subject to dropout. The list is never
            modified; a new list is built internally.
        undroppable_keywords: Keywords appended after dropout, so they are
            always present in the result. ``None`` means no extra keywords.
        dropout: Each entry of ``keywords`` is kept only when a uniform draw
            from ``[0, 1)`` is strictly greater than this value. ``0``
            disables dropout entirely (no random draws are made).
        shuffle: When true, the combined list is shuffled before joining.
        npgenerator: Random generator used for both dropout and shuffling.
            When omitted, dropout draws from NumPy's global random state and
            shuffling uses a fresh ``np.random.default_rng()``.

    Returns:
        The surviving keywords followed by the undroppable ones (optionally
        shuffled), joined with ``", "``.
    """
    if dropout != 0:
        # Honour the caller's generator so a seeded generator gives
        # reproducible dropout; fall back to the global RNG otherwise.
        draw = npgenerator.random if npgenerator is not None else np.random.random
        kept = [keyword for keyword in keywords if draw() > dropout]
    else:
        kept = list(keywords)

    if undroppable_keywords:
        kept.extend(undroppable_keywords)

    if shuffle:
        rng = npgenerator if npgenerator is not None else np.random.default_rng()
        rng.shuffle(kept)

    return ", ".join(kept)


def str_to_keywords(s: str, expansions: Optional[dict[str, str]] = None) -> list[str]:
    """Split a ``", "``-separated string into a list of keywords.

    Args:
        s: Keyword string using ``", "`` as the separator.
        expansions: Optional mapping from a keyword to a ``", "``-separated
            string of additional keywords. When a keyword from ``s`` is a key
            of this mapping, the additional keywords are inserted right after
            it. ``None`` means no expansions.

    Returns:
        The keywords in their original order, each followed by its expansion
        terms (if any). Empty keywords, such as those produced by an empty
        ``s`` or by an empty expansion term, are omitted.
    """
    if expansions is None:
        expansions = {}

    result: list[str] = []
    for keyword in s.split(", "):
        # Skip empty entries before expanding so "" is never looked up.
        if keyword == "":
            continue
        result.append(keyword)
        if keyword in expansions:
            # Empty expansion terms would otherwise leak "" into the result.
            result.extend(
                term for term in expansions[keyword].split(", ") if term != ""
            )
    return result