diff options
| -rw-r--r-- | data/csv.py | 5 | ||||
| -rw-r--r-- | data/keywords.py | 5 |
2 files changed, 5 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py index d0ac317..e1b92c1 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -142,10 +142,7 @@ class VlpnDataItem(NamedTuple): | |||
| 142 | collection: list[str] | 142 | collection: list[str] |
| 143 | 143 | ||
| 144 | def full_prompt(self, dropout: float = 0, shuffle: bool = False): | 144 | def full_prompt(self, dropout: float = 0, shuffle: bool = False): |
| 145 | prompt = self.prompt | 145 | return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle) |
| 146 | if len(self.keywords): | ||
| 147 | prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle) | ||
| 148 | return prompt | ||
| 149 | 146 | ||
| 150 | 147 | ||
| 151 | def keyword_filter( | 148 | def keyword_filter( |
diff --git a/data/keywords.py b/data/keywords.py index 7385809..629006d 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
| @@ -1,9 +1,12 @@ | |||
| 1 | import numpy as np | 1 | import numpy as np |
| 2 | 2 | ||
| 3 | 3 | ||
| 4 | def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str: | 4 | def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str: |
| 5 | if dropout != 0: | 5 | if dropout != 0: |
| 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
| 7 | else: | ||
| 8 | keywords = keywords.copy() | ||
| 9 | keywords += undroppable_keywords | ||
| 7 | if shuffle: | 10 | if shuffle: |
| 8 | np.random.shuffle(keywords) | 11 | np.random.shuffle(keywords) |
| 9 | return ", ".join(keywords) | 12 | return ", ".join(keywords) |
