summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py5
-rw-r--r--data/keywords.py5
2 files changed, 5 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py
index d0ac317..e1b92c1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -142,10 +142,7 @@ class VlpnDataItem(NamedTuple):
142 collection: list[str] 142 collection: list[str]
143 143
144 def full_prompt(self, dropout: float = 0, shuffle: bool = False): 144 def full_prompt(self, dropout: float = 0, shuffle: bool = False):
145 prompt = self.prompt 145 return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle)
146 if len(self.keywords):
147 prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle)
148 return prompt
149 146
150 147
151def keyword_filter( 148def keyword_filter(
diff --git a/data/keywords.py b/data/keywords.py
index 7385809..629006d 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,9 +1,12 @@
1import numpy as np 1import numpy as np
2 2
3 3
4def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: 4def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0: 5 if dropout != 0:
6 keywords = [keyword for keyword in keywords if np.random.random() > dropout] 6 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
7 else:
8 keywords = keywords.copy()
9 keywords += undroppable_keywords
7 if shuffle: 10 if shuffle:
8 np.random.shuffle(keywords) 11 np.random.shuffle(keywords)
9 return ", ".join(keywords) 12 return ", ".join(keywords)