diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-20 07:39:28 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-20 07:39:28 +0100 |
| commit | 29b79243d4ab3c5bc137e81b9dead8f57c45d5de (patch) | |
| tree | b0156dbfed7c9316033cf0aabb9ff28f93325f39 | |
| parent | Upstream patches (diff) | |
| download | textual-inversion-diff-29b79243d4ab3c5bc137e81b9dead8f57c45d5de.tar.gz textual-inversion-diff-29b79243d4ab3c5bc137e81b9dead8f57c45d5de.tar.bz2 textual-inversion-diff-29b79243d4ab3c5bc137e81b9dead8f57c45d5de.zip | |
Fix Textual Inversion dataset filtering
| -rw-r--r-- | train_ti.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index dbfe58c..5c0299e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -648,7 +648,11 @@ def main(): | |||
| 648 | weight_dtype = torch.bfloat16 | 648 | weight_dtype = torch.bfloat16 |
| 649 | 649 | ||
| 650 | def keyword_filter(item: CSVDataItem): | 650 | def keyword_filter(item: CSVDataItem): |
| 651 | return any(keyword in item.prompt for keyword in args.placeholder_token) | 651 | return any( |
| 652 | keyword in part | ||
| 653 | for keyword in args.placeholder_token | ||
| 654 | for part in item.prompt | ||
| 655 | ) | ||
| 652 | 656 | ||
| 653 | def collate_fn(examples): | 657 | def collate_fn(examples): |
| 654 | prompts = [example["prompts"] for example in examples] | 658 | prompts = [example["prompts"] for example in examples] |
