diff options
-rw-r--r-- | data/csv.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index eaef5e6..7527b7d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -70,7 +70,9 @@ def generate_buckets( | |||
70 | mask = torch.logical_and(buckets <= 1, buckets >= ratio) | 70 | mask = torch.logical_and(buckets <= 1, buckets >= ratio) |
71 | 71 | ||
72 | if not progressive_buckets: | 72 | if not progressive_buckets: |
73 | mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() | 73 | inf = torch.zeros_like(buckets) |
74 | inf[~mask] = math.inf | ||
75 | mask = (buckets + inf - ratio).abs().argmin() | ||
74 | 76 | ||
75 | indices = bucket_indices[mask] | 77 | indices = bucket_indices[mask] |
76 | 78 | ||
@@ -321,8 +323,8 @@ class VlpnDataset(IterableDataset): | |||
321 | 323 | ||
322 | bucket = self.bucket_assignments[mask][0] | 324 | bucket = self.bucket_assignments[mask][0] |
323 | ratio = self.buckets[bucket] | 325 | ratio = self.buckets[bucket] |
324 | width = self.size * ratio if ratio > 1 else self.size | 326 | width = int(self.size * ratio) if ratio > 1 else self.size |
325 | height = self.size / ratio if ratio < 1 else self.size | 327 | height = int(self.size / ratio) if ratio < 1 else self.size |
326 | 328 | ||
327 | image_transforms = transforms.Compose( | 329 | image_transforms = transforms.Compose( |
328 | [ | 330 | [ |