diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py index e901ab4..c505230 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -165,19 +165,27 @@ class CSVDataModule(): | |||
165 | self.data_val = self.pad_items(data_val) | 165 | self.data_val = self.pad_items(data_val) |
166 | 166 | ||
167 | def setup(self, stage=None): | 167 | def setup(self, stage=None): |
168 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 168 | train_dataset = CSVDataset( |
169 | num_class_images=self.num_class_images, | 169 | self.data_train, self.prompt_processor, batch_size=self.batch_size, |
170 | size=self.size, interpolation=self.interpolation, | 170 | num_class_images=self.num_class_images, |
171 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) | 171 | size=self.size, interpolation=self.interpolation, |
172 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 172 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout |
173 | size=self.size, interpolation=self.interpolation, | 173 | ) |
174 | center_crop=self.center_crop) | 174 | val_dataset = CSVDataset( |
175 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 175 | self.data_val, self.prompt_processor, batch_size=self.batch_size, |
176 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, | 176 | size=self.size, interpolation=self.interpolation, |
177 | num_workers=self.num_workers) | 177 | center_crop=self.center_crop |
178 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 178 | ) |
179 | pin_memory=True, collate_fn=self.collate_fn, | 179 | self.train_dataloader_ = DataLoader( |
180 | num_workers=self.num_workers) | 180 | train_dataset, batch_size=self.batch_size, |
181 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, | ||
182 | num_workers=self.num_workers | ||
183 | ) | ||
184 | self.val_dataloader_ = DataLoader( | ||
185 | val_dataset, batch_size=self.batch_size, | ||
186 | pin_memory=True, collate_fn=self.collate_fn, | ||
187 | num_workers=self.num_workers | ||
188 | ) | ||
181 | 189 | ||
182 | def train_dataloader(self): | 190 | def train_dataloader(self): |
183 | return self.train_dataloader_ | 191 | return self.train_dataloader_ |
@@ -210,11 +218,12 @@ class CSVDataset(Dataset): | |||
210 | self.num_instance_images = len(self.data) | 218 | self.num_instance_images = len(self.data) |
211 | self._length = self.num_instance_images * repeats | 219 | self._length = self.num_instance_images * repeats |
212 | 220 | ||
213 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, | 221 | self.interpolation = { |
214 | "bilinear": transforms.InterpolationMode.BILINEAR, | 222 | "linear": transforms.InterpolationMode.NEAREST, |
215 | "bicubic": transforms.InterpolationMode.BICUBIC, | 223 | "bilinear": transforms.InterpolationMode.BILINEAR, |
216 | "lanczos": transforms.InterpolationMode.LANCZOS, | 224 | "bicubic": transforms.InterpolationMode.BICUBIC, |
217 | }[interpolation] | 225 | "lanczos": transforms.InterpolationMode.LANCZOS, |
226 | }[interpolation] | ||
218 | self.image_transforms = transforms.Compose( | 227 | self.image_transforms = transforms.Compose( |
219 | [ | 228 | [ |
220 | transforms.Resize(size, interpolation=self.interpolation), | 229 | transforms.Resize(size, interpolation=self.interpolation), |