diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-03 07:19:06 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-03 07:19:06 +0200 | 
| commit | 2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea (patch) | |
| tree | 93c017967359d71fb44965452f30ef1e419fdff5 /data/keywords.py | |
| parent | Update dataset format: Separate prompt and keywords (diff) | |
| download | textual-inversion-diff-2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea.tar.gz textual-inversion-diff-2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea.tar.bz2 textual-inversion-diff-2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea.zip | |
Fix memory leak
Diffstat (limited to 'data/keywords.py')
| -rw-r--r-- | data/keywords.py | 5 | 
1 files changed, 4 insertions, 1 deletions
| diff --git a/data/keywords.py b/data/keywords.py index 7385809..629006d 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
| @@ -1,9 +1,12 @@ | |||
| 1 | import numpy as np | 1 | import numpy as np | 
| 2 | 2 | ||
| 3 | 3 | ||
| 4 | def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: | 4 | def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str: | 
| 5 | if dropout != 0: | 5 | if dropout != 0: | 
| 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 
| 7 | else: | ||
| 8 | keywords = keywords.copy() | ||
| 9 | keywords += undroppable_keywords | ||
| 7 | if shuffle: | 10 | if shuffle: | 
| 8 | np.random.shuffle(keywords) | 11 | np.random.shuffle(keywords) | 
| 9 | return ", ".join(keywords) | 12 | return ", ".join(keywords) | 
