diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 10 |
1 file changed, 5 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py
index 9c3c3f8..20ac992 100644
--- a/data/csv.py
+++ b/data/csv.py
| @@ -7,7 +7,7 @@ import pytorch_lightning as pl | |||
| 7 | from PIL import Image | 7 | from PIL import Image |
| 8 | from torch.utils.data import Dataset, DataLoader, random_split | 8 | from torch.utils.data import Dataset, DataLoader, random_split |
| 9 | from torchvision import transforms | 9 | from torchvision import transforms |
| 10 | from typing import Dict, NamedTuple, List, Optional, Union | 10 | from typing import Dict, NamedTuple, List, Optional, Union, Callable |
| 11 | 11 | ||
| 12 | from models.clip.prompt import PromptProcessor | 12 | from models.clip.prompt import PromptProcessor |
| 13 | 13 | ||
| @@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 57 | template_key: str = "template", | 57 | template_key: str = "template", |
| 58 | valid_set_size: Optional[int] = None, | 58 | valid_set_size: Optional[int] = None, |
| 59 | generator: Optional[torch.Generator] = None, | 59 | generator: Optional[torch.Generator] = None, |
| 60 | keyword_filter: list[str] = [], | 60 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
| 61 | collate_fn=None, | 61 | collate_fn=None, |
| 62 | num_workers: int = 0 | 62 | num_workers: int = 0 |
| 63 | ): | 63 | ): |
| @@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 84 | self.interpolation = interpolation | 84 | self.interpolation = interpolation |
| 85 | self.valid_set_size = valid_set_size | 85 | self.valid_set_size = valid_set_size |
| 86 | self.generator = generator | 86 | self.generator = generator |
| 87 | self.keyword_filter = keyword_filter | 87 | self.filter = filter |
| 88 | self.collate_fn = collate_fn | 88 | self.collate_fn = collate_fn |
| 89 | self.num_workers = num_workers | 89 | self.num_workers = num_workers |
| 90 | self.batch_size = batch_size | 90 | self.batch_size = batch_size |
| @@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 105 | ] | 105 | ] |
| 106 | 106 | ||
| 107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | 107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: |
| 108 | if len(self.keyword_filter) == 0: | 108 | if self.filter is None: |
| 109 | return items | 109 | return items |
| 110 | 110 | ||
| 111 | return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] | 111 | return [item for item in items if self.filter(item)] |
| 112 | 112 | ||
| 113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | 113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: |
| 114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | 114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) |
