From 29b79243d4ab3c5bc137e81b9dead8f57c45d5de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 20 Dec 2022 07:39:28 +0100 Subject: Fix Textual Inversion dataset filtering --- train_ti.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_ti.py b/train_ti.py index dbfe58c..5c0299e 100644 --- a/train_ti.py +++ b/train_ti.py @@ -648,7 +648,11 @@ def main(): weight_dtype = torch.bfloat16 def keyword_filter(item: CSVDataItem): - return any(keyword in item.prompt for keyword in args.placeholder_token) + return any( + keyword in part + for keyword in args.placeholder_token + for part in item.prompt + ) def collate_fn(examples): prompts = [example["prompts"] for example in examples] -- cgit v1.2.3-70-g09d2