From 1a0161f345191d78a19eec829f9d73b2c2c72f94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 09:44:12 +0200 Subject: Update --- data/keywords.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'data/keywords.py') diff --git a/data/keywords.py b/data/keywords.py index 629006d..8632d67 100644 --- a/data/keywords.py +++ b/data/keywords.py @@ -1,14 +1,23 @@ +from typing import Optional + import numpy as np -def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str: +def keywords_to_str( + keywords: list[str], + undroppable_keywords: list[str] = [], + dropout: float = 0, + shuffle: bool = False, + npgenerator: Optional[np.random.Generator] = None +) -> str: if dropout != 0: keywords = [keyword for keyword in keywords if np.random.random() > dropout] else: keywords = keywords.copy() keywords += undroppable_keywords if shuffle: - np.random.shuffle(keywords) + npgenerator = npgenerator or np.random.default_rng() + npgenerator.shuffle(keywords) return ", ".join(keywords) -- cgit v1.2.3-54-g00ecf