From 279174a7a31f0fc6ed209e5b46901e50fe722c87 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 14 Dec 2022 09:43:45 +0100 Subject: More generic dataset filter --- data/csv.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 9c3c3f8..20ac992 100644 --- a/data/csv.py +++ b/data/csv.py @@ -7,7 +7,7 @@ import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms -from typing import Dict, NamedTuple, List, Optional, Union +from typing import Dict, NamedTuple, List, Optional, Union, Callable from models.clip.prompt import PromptProcessor @@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule): template_key: str = "template", valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, - keyword_filter: list[str] = [], + filter: Optional[Callable[[CSVDataItem], bool]] = None, collate_fn=None, num_workers: int = 0 ): @@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule): self.interpolation = interpolation self.valid_set_size = valid_set_size self.generator = generator - self.keyword_filter = keyword_filter + self.filter = filter self.collate_fn = collate_fn self.num_workers = num_workers self.batch_size = batch_size @@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule): ] def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: - if len(self.keyword_filter) == 0: + if self.filter is None: return items - return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] + return [item for item in items if self.filter(item)] def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: image_multiplier = max(math.ceil(num_class_images / len(items)), 1) -- cgit v1.2.3-70-g09d2