summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 10
1 files changed, 5 insertions, 5 deletions
diff --git a/data/csv.py b/data/csv.py
index 9c3c3f8..20ac992 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,7 +7,7 @@ import pytorch_lightning as pl
7from PIL import Image 7from PIL import Image
8from torch.utils.data import Dataset, DataLoader, random_split 8from torch.utils.data import Dataset, DataLoader, random_split
9from torchvision import transforms 9from torchvision import transforms
10from typing import Dict, NamedTuple, List, Optional, Union 10from typing import Dict, NamedTuple, List, Optional, Union, Callable
11 11
12from models.clip.prompt import PromptProcessor 12from models.clip.prompt import PromptProcessor
13 13
@@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule):
57 template_key: str = "template", 57 template_key: str = "template",
58 valid_set_size: Optional[int] = None, 58 valid_set_size: Optional[int] = None,
59 generator: Optional[torch.Generator] = None, 59 generator: Optional[torch.Generator] = None,
60 keyword_filter: list[str] = [], 60 filter: Optional[Callable[[CSVDataItem], bool]] = None,
61 collate_fn=None, 61 collate_fn=None,
62 num_workers: int = 0 62 num_workers: int = 0
63 ): 63 ):
@@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule):
84 self.interpolation = interpolation 84 self.interpolation = interpolation
85 self.valid_set_size = valid_set_size 85 self.valid_set_size = valid_set_size
86 self.generator = generator 86 self.generator = generator
87 self.keyword_filter = keyword_filter 87 self.filter = filter
88 self.collate_fn = collate_fn 88 self.collate_fn = collate_fn
89 self.num_workers = num_workers 89 self.num_workers = num_workers
90 self.batch_size = batch_size 90 self.batch_size = batch_size
@@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule):
105 ] 105 ]
106 106
107 def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: 107 def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
108 if len(self.keyword_filter) == 0: 108 if self.filter is None:
109 return items 109 return items
110 110
111 return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] 111 return [item for item in items if self.filter(item)]
112 112
113 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: 113 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
114 image_multiplier = max(math.ceil(num_class_images / len(items)), 1) 114 image_multiplier = max(math.ceil(num_class_images / len(items)), 1)