summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py45
1 files changed, 27 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py
index e901ab4..c505230 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -165,19 +165,27 @@ class CSVDataModule():
165 self.data_val = self.pad_items(data_val) 165 self.data_val = self.pad_items(data_val)
166 166
167 def setup(self, stage=None): 167 def setup(self, stage=None):
168 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, 168 train_dataset = CSVDataset(
169 num_class_images=self.num_class_images, 169 self.data_train, self.prompt_processor, batch_size=self.batch_size,
170 size=self.size, interpolation=self.interpolation, 170 num_class_images=self.num_class_images,
171 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) 171 size=self.size, interpolation=self.interpolation,
172 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, 172 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout
173 size=self.size, interpolation=self.interpolation, 173 )
174 center_crop=self.center_crop) 174 val_dataset = CSVDataset(
175 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 175 self.data_val, self.prompt_processor, batch_size=self.batch_size,
176 shuffle=True, pin_memory=True, collate_fn=self.collate_fn, 176 size=self.size, interpolation=self.interpolation,
177 num_workers=self.num_workers) 177 center_crop=self.center_crop
178 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 178 )
179 pin_memory=True, collate_fn=self.collate_fn, 179 self.train_dataloader_ = DataLoader(
180 num_workers=self.num_workers) 180 train_dataset, batch_size=self.batch_size,
181 shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
182 num_workers=self.num_workers
183 )
184 self.val_dataloader_ = DataLoader(
185 val_dataset, batch_size=self.batch_size,
186 pin_memory=True, collate_fn=self.collate_fn,
187 num_workers=self.num_workers
188 )
181 189
182 def train_dataloader(self): 190 def train_dataloader(self):
183 return self.train_dataloader_ 191 return self.train_dataloader_
@@ -210,11 +218,12 @@ class CSVDataset(Dataset):
210 self.num_instance_images = len(self.data) 218 self.num_instance_images = len(self.data)
211 self._length = self.num_instance_images * repeats 219 self._length = self.num_instance_images * repeats
212 220
213 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, 221 self.interpolation = {
214 "bilinear": transforms.InterpolationMode.BILINEAR, 222 "linear": transforms.InterpolationMode.NEAREST,
215 "bicubic": transforms.InterpolationMode.BICUBIC, 223 "bilinear": transforms.InterpolationMode.BILINEAR,
216 "lanczos": transforms.InterpolationMode.LANCZOS, 224 "bicubic": transforms.InterpolationMode.BICUBIC,
217 }[interpolation] 225 "lanczos": transforms.InterpolationMode.LANCZOS,
226 }[interpolation]
218 self.image_transforms = transforms.Compose( 227 self.image_transforms = transforms.Compose(
219 [ 228 [
220 transforms.Resize(size, interpolation=self.interpolation), 229 transforms.Resize(size, interpolation=self.interpolation),