summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/data/keywords.py b/data/keywords.py
index 629006d..8632d67 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,14 +1,23 @@
1from typing import Optional
2
1import numpy as np 3import numpy as np
2 4
3 5
4def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str: 6def keywords_to_str(
7 keywords: list[str],
8 undroppable_keywords: list[str] = [],
9 dropout: float = 0,
10 shuffle: bool = False,
11 npgenerator: Optional[np.random.Generator] = None
12) -> str:
5 if dropout != 0: 13 if dropout != 0:
6 keywords = [keyword for keyword in keywords if np.random.random() > dropout] 14 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
7 else: 15 else:
8 keywords = keywords.copy() 16 keywords = keywords.copy()
9 keywords += undroppable_keywords 17 keywords += undroppable_keywords
10 if shuffle: 18 if shuffle:
11 np.random.shuffle(keywords) 19 npgenerator = npgenerator or np.random.default_rng()
20 npgenerator.shuffle(keywords)
12 return ", ".join(keywords) 21 return ", ".join(keywords)
13 22
14 23