From 3c6ccadd3c12c54a1fa2280bce505a2dd511958a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 07:27:45 +0100 Subject: Implemented extended Dreambooth training --- data/csv.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 2a8115b..2b1e202 100644 --- a/data/csv.py +++ b/data/csv.py @@ -130,6 +130,25 @@ class VlpnDataItem(NamedTuple): collection: list[str] +def keyword_filter( + placeholder_tokens: Optional[list[str]], + collection: Optional[list[str]], + exclude_collections: Optional[list[str]], + item: VlpnDataItem +): + cond1 = placeholder_tokens is None or any( + keyword in part + for keyword in placeholder_tokens + for part in item.prompt + ) + cond2 = collection is None or collection in item.collection + cond3 = exclude_collections is None or not any( + collection in item.collection + for collection in exclude_collections + ) + return cond1 and cond2 and cond3 + + class VlpnDataModule(): def __init__( self, -- cgit v1.2.3-70-g09d2