summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-28 16:24:22 +0200
committerVolpeon <git@volpeon.ink>2023-03-28 16:24:22 +0200
commit179a45253a5b3712f32bd127f693a6bb810a9c17 (patch)
treeac9f1152d858089742e4f9ce79e0870e0f2b9a2d /data
parentFix TI (diff)
downloadtextual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.tar.gz
textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.tar.bz2
textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.zip
Support num_train_steps arg again
Diffstat (limited to 'data')
-rw-r--r--data/csv.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/data/csv.py b/data/csv.py
index 9770bec..c00ea07 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -143,7 +143,7 @@ class VlpnDataItem(NamedTuple):
143 143
144def keyword_filter( 144def keyword_filter(
145 placeholder_tokens: Optional[list[str]], 145 placeholder_tokens: Optional[list[str]],
146 collection: Optional[list[str]], 146 collections: Optional[list[str]],
147 exclude_collections: Optional[list[str]], 147 exclude_collections: Optional[list[str]],
148 item: VlpnDataItem 148 item: VlpnDataItem
149): 149):
@@ -152,11 +152,15 @@ def keyword_filter(
152 for keyword in placeholder_tokens 152 for keyword in placeholder_tokens
153 for part in item.prompt 153 for part in item.prompt
154 ) 154 )
155 cond2 = collection is None or collection in item.collection 155 cond2 = collections is None or any(
156 collection in item.collection
157 for collection in collections
158 )
156 cond3 = exclude_collections is None or not any( 159 cond3 = exclude_collections is None or not any(
157 collection in item.collection 160 collection in item.collection
158 for collection in exclude_collections 161 for collection in exclude_collections
159 ) 162 )
163
160 return cond1 and cond2 and cond3 164 return cond1 and cond2 and cond3
161 165
162 166