diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py index 9c3c3f8..20ac992 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -7,7 +7,7 @@ import pytorch_lightning as pl | |||
7 | from PIL import Image | 7 | from PIL import Image |
8 | from torch.utils.data import Dataset, DataLoader, random_split | 8 | from torch.utils.data import Dataset, DataLoader, random_split |
9 | from torchvision import transforms | 9 | from torchvision import transforms |
10 | from typing import Dict, NamedTuple, List, Optional, Union | 10 | from typing import Dict, NamedTuple, List, Optional, Union, Callable |
11 | 11 | ||
12 | from models.clip.prompt import PromptProcessor | 12 | from models.clip.prompt import PromptProcessor |
13 | 13 | ||
@@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
57 | template_key: str = "template", | 57 | template_key: str = "template", |
58 | valid_set_size: Optional[int] = None, | 58 | valid_set_size: Optional[int] = None, |
59 | generator: Optional[torch.Generator] = None, | 59 | generator: Optional[torch.Generator] = None, |
60 | keyword_filter: list[str] = [], | 60 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
61 | collate_fn=None, | 61 | collate_fn=None, |
62 | num_workers: int = 0 | 62 | num_workers: int = 0 |
63 | ): | 63 | ): |
@@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
84 | self.interpolation = interpolation | 84 | self.interpolation = interpolation |
85 | self.valid_set_size = valid_set_size | 85 | self.valid_set_size = valid_set_size |
86 | self.generator = generator | 86 | self.generator = generator |
87 | self.keyword_filter = keyword_filter | 87 | self.filter = filter |
88 | self.collate_fn = collate_fn | 88 | self.collate_fn = collate_fn |
89 | self.num_workers = num_workers | 89 | self.num_workers = num_workers |
90 | self.batch_size = batch_size | 90 | self.batch_size = batch_size |
@@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
105 | ] | 105 | ] |
106 | 106 | ||
107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | 107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: |
108 | if len(self.keyword_filter) == 0: | 108 | if self.filter is None: |
109 | return items | 109 | return items |
110 | 110 | ||
111 | return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] | 111 | return [item for item in items if self.filter(item)] |
112 | 112 | ||
113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | 113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: |
114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | 114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) |