summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
blob: 9e656f3312b5bf3de410b7d1899bed531177c4ab (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np


def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
    """Join a list of keywords into a single comma-separated prompt string.

    Args:
        prompt: Keywords to join. The caller's list is never modified.
        dropout: Probability of independently discarding each keyword.
            ``0`` disables dropout; ``1`` (or more) discards every keyword.
        shuffle: If true, randomly permute the surviving keywords before
            joining.

    Returns:
        The surviving keywords joined with ``", "``; an empty string if none
        remain.

    Randomness is drawn from NumPy's global RNG (``np.random``), so results
    are reproducible under ``np.random.seed``.
    """
    # Work on a copy: np.random.shuffle permutes in place, so without this
    # copy, shuffling with dropout disabled would reorder the caller's list.
    keywords = list(prompt)
    if dropout != 0:
        # random() is in [0, 1), so each keyword survives with
        # probability (1 - dropout).
        keywords = [keyword for keyword in keywords if np.random.random() > dropout]
    if shuffle:
        np.random.shuffle(keywords)
    return ", ".join(keywords)


def prompt_to_keywords(prompt: str, expansions: "dict[str, str] | None" = None) -> list[str]:
    """Split a comma-separated prompt into keywords, inlining expansions.

    The prompt is split on the exact separator ``", "`` (comma plus one
    space). Each keyword that is a key of ``expansions`` is immediately
    followed by the keywords of its expansion value, which is split the same
    way. Expansion happens one level deep: inserted keywords are not expanded
    again. Empty keywords, whether from the prompt or from an expansion, are
    dropped.

    Args:
        prompt: Comma-separated keyword string, e.g. ``"cat, dog"``.
        expansions: Optional mapping from a keyword to a comma-separated
            string of extra keywords to insert right after it. Not modified.

    Returns:
        The keywords in order, with expansions inlined.
    """
    # None instead of a shared mutable ``{}`` default.
    if expansions is None:
        expansions = {}

    keywords: list[str] = []
    for keyword in prompt.split(", "):
        if keyword == "":
            continue
        keywords.append(keyword)
        if keyword in expansions:
            # Filter empties here too, so an expansion like "" or "a, , b"
            # does not inject blank keywords.
            keywords.extend(kw for kw in expansions[keyword].split(", ") if kw != "")
    return keywords