diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index e1b92c1..818fcd9 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -186,6 +186,7 @@ class VlpnDataModule(): | |||
186 | dropout: float = 0, | 186 | dropout: float = 0, |
187 | shuffle: bool = False, | 187 | shuffle: bool = False, |
188 | interpolation: str = "bicubic", | 188 | interpolation: str = "bicubic", |
189 | color_jitter: bool = True, | ||
189 | template_key: str = "template", | 190 | template_key: str = "template", |
190 | placeholder_tokens: list[str] = [], | 191 | placeholder_tokens: list[str] = [], |
191 | valid_set_size: Optional[int] = None, | 192 | valid_set_size: Optional[int] = None, |
@@ -219,6 +220,7 @@ class VlpnDataModule(): | |||
219 | self.shuffle = shuffle | 220 | self.shuffle = shuffle |
220 | self.template_key = template_key | 221 | self.template_key = template_key |
221 | self.interpolation = interpolation | 222 | self.interpolation = interpolation |
223 | self.color_jitter = color_jitter | ||
222 | self.valid_set_size = valid_set_size | 224 | self.valid_set_size = valid_set_size |
223 | self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size | 225 | self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size |
224 | self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size | 226 | self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size |
@@ -323,7 +325,7 @@ class VlpnDataModule(): | |||
323 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 325 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
324 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 326 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
325 | batch_size=self.batch_size, fill_batch=True, generator=generator, | 327 | batch_size=self.batch_size, fill_batch=True, generator=generator, |
326 | size=self.size, interpolation=self.interpolation, | 328 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, |
327 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 329 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
328 | ) | 330 | ) |
329 | 331 | ||
@@ -343,7 +345,7 @@ class VlpnDataModule(): | |||
343 | num_buckets=self.num_buckets, progressive_buckets=True, | 345 | num_buckets=self.num_buckets, progressive_buckets=True, |
344 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 346 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
345 | batch_size=self.batch_size, generator=generator, | 347 | batch_size=self.batch_size, generator=generator, |
346 | size=self.size, interpolation=self.interpolation, | 348 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, |
347 | ) | 349 | ) |
348 | 350 | ||
349 | self.val_dataloader = DataLoader( | 351 | self.val_dataloader = DataLoader( |
@@ -370,6 +372,7 @@ class VlpnDataset(IterableDataset): | |||
370 | dropout: float = 0, | 372 | dropout: float = 0, |
371 | shuffle: bool = False, | 373 | shuffle: bool = False, |
372 | interpolation: str = "bicubic", | 374 | interpolation: str = "bicubic", |
375 | color_jitter: bool = True, | ||
373 | generator: Optional[torch.Generator] = None, | 376 | generator: Optional[torch.Generator] = None, |
374 | ): | 377 | ): |
375 | self.items = items | 378 | self.items = items |
@@ -382,6 +385,7 @@ class VlpnDataset(IterableDataset): | |||
382 | self.dropout = dropout | 385 | self.dropout = dropout |
383 | self.shuffle = shuffle | 386 | self.shuffle = shuffle |
384 | self.interpolation = interpolations[interpolation] | 387 | self.interpolation = interpolations[interpolation] |
388 | self.color_jitter = color_jitter | ||
385 | self.generator = generator | 389 | self.generator = generator |
386 | 390 | ||
387 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 391 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
@@ -446,15 +450,20 @@ class VlpnDataset(IterableDataset): | |||
446 | width = int(self.size * ratio) if ratio > 1 else self.size | 450 | width = int(self.size * ratio) if ratio > 1 else self.size |
447 | height = int(self.size / ratio) if ratio < 1 else self.size | 451 | height = int(self.size / ratio) if ratio < 1 else self.size |
448 | 452 | ||
449 | image_transforms = transforms.Compose( | 453 | image_transforms = [ |
450 | [ | 454 | transforms.Resize(self.size, interpolation=self.interpolation), |
451 | transforms.Resize(self.size, interpolation=self.interpolation), | 455 | transforms.RandomCrop((height, width)), |
452 | transforms.RandomCrop((height, width)), | 456 | transforms.RandomHorizontalFlip(), |
453 | transforms.RandomHorizontalFlip(), | 457 | ] |
454 | transforms.ToTensor(), | 458 | if self.color_jitter: |
455 | transforms.Normalize([0.5], [0.5]), | 459 | image_transforms += [ |
460 | transforms.ColorJitter(0.2, 0.1), | ||
456 | ] | 461 | ] |
457 | ) | 462 | image_transforms += [ |
463 | transforms.ToTensor(), | ||
464 | transforms.Normalize([0.5], [0.5]), | ||
465 | ] | ||
466 | image_transforms = transforms.Compose(image_transforms) | ||
458 | 467 | ||
459 | continue | 468 | continue |
460 | 469 | ||