From f4f996681ca340e940315ca0ebc162c655904a7d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 5 Apr 2023 16:02:04 +0200 Subject: Add color jitter --- data/csv.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index e1b92c1..818fcd9 100644 --- a/data/csv.py +++ b/data/csv.py @@ -186,6 +186,7 @@ class VlpnDataModule(): dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", + color_jitter: bool = True, template_key: str = "template", placeholder_tokens: list[str] = [], valid_set_size: Optional[int] = None, @@ -219,6 +220,7 @@ class VlpnDataModule(): self.shuffle = shuffle self.template_key = template_key self.interpolation = interpolation + self.color_jitter = color_jitter self.valid_set_size = valid_set_size self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size @@ -323,7 +325,7 @@ class VlpnDataModule(): num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, fill_batch=True, generator=generator, - size=self.size, interpolation=self.interpolation, + size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) @@ -343,7 +345,7 @@ class VlpnDataModule(): num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, - size=self.size, interpolation=self.interpolation, + size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, ) self.val_dataloader = DataLoader( @@ -370,6 +372,7 @@ class VlpnDataset(IterableDataset): dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", + color_jitter: bool = True, generator: Optional[torch.Generator] = None, ): self.items = items @@ -382,6 +385,7 @@ class VlpnDataset(IterableDataset): self.dropout = dropout self.shuffle = shuffle self.interpolation = interpolations[interpolation] + self.color_jitter = color_jitter self.generator = generator self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( @@ -446,15 +450,20 @@ class VlpnDataset(IterableDataset): width = int(self.size * ratio) if ratio > 1 else self.size height = int(self.size / ratio) if ratio < 1 else self.size - image_transforms = transforms.Compose( - [ - transforms.Resize(self.size, interpolation=self.interpolation), - transforms.RandomCrop((height, width)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), + image_transforms = [ + transforms.Resize(self.size, interpolation=self.interpolation), + transforms.RandomCrop((height, width)), + transforms.RandomHorizontalFlip(), + ] + if self.color_jitter: + image_transforms += [ + transforms.ColorJitter(0.2, 0.1), ] - ) + image_transforms += [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + image_transforms = transforms.Compose(image_transforms) continue -- cgit v1.2.3-70-g09d2