summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-20 07:39:28 +0100
committerVolpeon <git@volpeon.ink>2022-12-20 07:39:28 +0100
commit29b79243d4ab3c5bc137e81b9dead8f57c45d5de (patch)
treeb0156dbfed7c9316033cf0aabb9ff28f93325f39 /train_ti.py
parentUpstream patches (diff)
downloadtextual-inversion-diff-29b79243d4ab3c5bc137e81b9dead8f57c45d5de.tar.gz
textual-inversion-diff-29b79243d4ab3c5bc137e81b9dead8f57c45d5de.tar.bz2
textual-inversion-diff-29b79243d4ab3c5bc137e81b9dead8f57c45d5de.zip
Fix Textual Inversion dataset filtering
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index dbfe58c..5c0299e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -648,7 +648,11 @@ def main():
648 weight_dtype = torch.bfloat16 648 weight_dtype = torch.bfloat16
649 649
650 def keyword_filter(item: CSVDataItem): 650 def keyword_filter(item: CSVDataItem):
651 return any(keyword in item.prompt for keyword in args.placeholder_token) 651 return any(
652 keyword in part
653 for keyword in args.placeholder_token
654 for part in item.prompt
655 )
652 656
653 def collate_fn(examples): 657 def collate_fn(examples):
654 prompts = [example["prompts"] for example in examples] 658 prompts = [example["prompts"] for example in examples]