summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/data/keywords.py b/data/keywords.py
index 7385809..629006d 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,9 +1,12 @@
1import numpy as np 1import numpy as np
2 2
3 3
4def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: 4def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0: 5 if dropout != 0:
6 keywords = [keyword for keyword in keywords if np.random.random() > dropout] 6 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
7 else:
8 keywords = keywords.copy()
9 keywords += undroppable_keywords
7 if shuffle: 10 if shuffle:
8 np.random.shuffle(keywords) 11 np.random.shuffle(keywords)
9 return ", ".join(keywords) 12 return ", ".join(keywords)