diff options
author | Volpeon <git@volpeon.ink> | 2023-04-16 09:44:12 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-16 09:44:12 +0200 |
commit | 1a0161f345191d78a19eec829f9d73b2c2c72f94 (patch) | |
tree | 6d7bcc67672ebf26454b3254b4bd9d5ec7e64a16 /data/keywords.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.gz textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.bz2 textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.zip |
Update
Diffstat (limited to 'data/keywords.py')
-rw-r--r-- | data/keywords.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/data/keywords.py b/data/keywords.py index 629006d..8632d67 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
@@ -1,14 +1,23 @@ | |||
1 | from typing import Optional | ||
2 | |||
1 | import numpy as np | 3 | import numpy as np |
2 | 4 | ||
3 | 5 | ||
4 | def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str: | 6 | def keywords_to_str( |
7 | keywords: list[str], | ||
8 | undroppable_keywords: list[str] = [], | ||
9 | dropout: float = 0, | ||
10 | shuffle: bool = False, | ||
11 | npgenerator: Optional[np.random.Generator] = None | ||
12 | ) -> str: | ||
5 | if dropout != 0: | 13 | if dropout != 0: |
6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
7 | else: | 15 | else: |
8 | keywords = keywords.copy() | 16 | keywords = keywords.copy() |
9 | keywords += undroppable_keywords | 17 | keywords += undroppable_keywords |
10 | if shuffle: | 18 | if shuffle: |
11 | np.random.shuffle(keywords) | 19 | npgenerator = npgenerator or np.random.default_rng() |
20 | npgenerator.shuffle(keywords) | ||
12 | return ", ".join(keywords) | 21 | return ", ".join(keywords) |
13 | 22 | ||
14 | 23 | ||