From 8373527ceed371338803a54721a48212bcad565f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 15:21:37 +0100 Subject: Fixed aspect ratio bucketing --- data/csv.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/data/csv.py b/data/csv.py index eaef5e6..7527b7d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -70,7 +70,9 @@ def generate_buckets( mask = torch.logical_and(buckets <= 1, buckets >= ratio) if not progressive_buckets: - mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() + inf = torch.zeros_like(buckets) + inf[~mask] = math.inf + mask = (buckets + inf - ratio).abs().argmin() indices = bucket_indices[mask] @@ -321,8 +323,8 @@ class VlpnDataset(IterableDataset): bucket = self.bucket_assignments[mask][0] ratio = self.buckets[bucket] - width = self.size * ratio if ratio > 1 else self.size - height = self.size / ratio if ratio < 1 else self.size + width = int(self.size * ratio) if ratio > 1 else self.size + height = int(self.size / ratio) if ratio < 1 else self.size image_transforms = transforms.Compose( [ -- cgit v1.2.3-70-g09d2