From 179a45253a5b3712f32bd127f693a6bb810a9c17 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 28 Mar 2023 16:24:22 +0200 Subject: Support num_train_steps arg again --- data/csv.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 9770bec..c00ea07 100644 --- a/data/csv.py +++ b/data/csv.py @@ -143,7 +143,7 @@ class VlpnDataItem(NamedTuple): def keyword_filter( placeholder_tokens: Optional[list[str]], - collection: Optional[list[str]], + collections: Optional[list[str]], exclude_collections: Optional[list[str]], item: VlpnDataItem ): @@ -152,11 +152,15 @@ def keyword_filter( for keyword in placeholder_tokens for part in item.prompt ) - cond2 = collection is None or collection in item.collection + cond2 = collections is None or any( + collection in item.collection + for collection in collections + ) cond3 = exclude_collections is None or not any( collection in item.collection for collection in exclude_collections ) + return cond1 and cond2 and cond3 -- cgit v1.2.3-54-g00ecf