From 01fee7d37a116265edb0f16e0b2f75d2116eb9f6 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 12:18:07 +0100 Subject: Various updates --- data/csv.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index e901ab4..c505230 100644 --- a/data/csv.py +++ b/data/csv.py @@ -165,19 +165,27 @@ class CSVDataModule(): self.data_val = self.pad_items(data_val) def setup(self, stage=None): - train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, - num_class_images=self.num_class_images, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) - val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, - shuffle=True, pin_memory=True, collate_fn=self.collate_fn, - num_workers=self.num_workers) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, - pin_memory=True, collate_fn=self.collate_fn, - num_workers=self.num_workers) + train_dataset = CSVDataset( + self.data_train, self.prompt_processor, batch_size=self.batch_size, + num_class_images=self.num_class_images, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout + ) + val_dataset = CSVDataset( + self.data_val, self.prompt_processor, batch_size=self.batch_size, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop + ) + self.train_dataloader_ = DataLoader( + train_dataset, batch_size=self.batch_size, + shuffle=True, pin_memory=True, collate_fn=self.collate_fn, + num_workers=self.num_workers + ) + self.val_dataloader_ = DataLoader( + val_dataset, batch_size=self.batch_size, + pin_memory=True, collate_fn=self.collate_fn, + num_workers=self.num_workers + ) def train_dataloader(self): return self.train_dataloader_ @@ -210,11 +218,12 @@ class CSVDataset(Dataset): self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats - self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, - "bilinear": transforms.InterpolationMode.BILINEAR, - "bicubic": transforms.InterpolationMode.BICUBIC, - "lanczos": transforms.InterpolationMode.LANCZOS, - }[interpolation] + self.interpolation = { + "linear": transforms.InterpolationMode.NEAREST, + "bilinear": transforms.InterpolationMode.BILINEAR, + "bicubic": transforms.InterpolationMode.BICUBIC, + "lanczos": transforms.InterpolationMode.LANCZOS, + }[interpolation] self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=self.interpolation), -- cgit v1.2.3-54-g00ecf