summaryrefslogtreecommitdiffstats
path: root/data/keywords.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-03 07:19:06 +0200
committerVolpeon <git@volpeon.ink>2023-04-03 07:19:06 +0200
commit2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea (patch)
tree93c017967359d71fb44965452f30ef1e419fdff5 /data/keywords.py
parentUpdate dataset format: Separate prompt and keywords (diff)
downloadtextual-inversion-diff-2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea.tar.gz
textual-inversion-diff-2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea.tar.bz2
textual-inversion-diff-2bf80e40762a3e7a62ebcc89640f9a6deda2d3ea.zip
Fix memory leak
Diffstat (limited to 'data/keywords.py')
-rw-r--r--data/keywords.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/data/keywords.py b/data/keywords.py
index 7385809..629006d 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,9 +1,12 @@
1import numpy as np 1import numpy as np
2 2
3 3
4def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: 4def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0: 5 if dropout != 0:
6 keywords = [keyword for keyword in keywords if np.random.random() > dropout] 6 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
7 else:
8 keywords = keywords.copy()
9 keywords += undroppable_keywords
7 if shuffle: 10 if shuffle:
8 np.random.shuffle(keywords) 11 np.random.shuffle(keywords)
9 return ", ".join(keywords) 12 return ", ".join(keywords)