diff options
author | Volpeon <git@volpeon.ink> | 2023-01-08 15:21:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-08 15:21:37 +0100 |
commit | 8373527ceed371338803a54721a48212bcad565f (patch) | |
tree | 501ba1f5dc7ef98f104ab52565c9164fd3d1bf32 | |
parent | Cleanup (diff) | |
download | textual-inversion-diff-8373527ceed371338803a54721a48212bcad565f.tar.gz textual-inversion-diff-8373527ceed371338803a54721a48212bcad565f.tar.bz2 textual-inversion-diff-8373527ceed371338803a54721a48212bcad565f.zip |
Fixed aspect ratio bucketing
-rw-r--r-- | data/csv.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index eaef5e6..7527b7d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -70,7 +70,9 @@ def generate_buckets( | |||
70 | mask = torch.logical_and(buckets <= 1, buckets >= ratio) | 70 | mask = torch.logical_and(buckets <= 1, buckets >= ratio) |
71 | 71 | ||
72 | if not progressive_buckets: | 72 | if not progressive_buckets: |
73 | mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() | 73 | inf = torch.zeros_like(buckets) |
74 | inf[~mask] = math.inf | ||
75 | mask = (buckets + inf - ratio).abs().argmin() | ||
74 | 76 | ||
75 | indices = bucket_indices[mask] | 77 | indices = bucket_indices[mask] |
76 | 78 | ||
@@ -321,8 +323,8 @@ class VlpnDataset(IterableDataset): | |||
321 | 323 | ||
322 | bucket = self.bucket_assignments[mask][0] | 324 | bucket = self.bucket_assignments[mask][0] |
323 | ratio = self.buckets[bucket] | 325 | ratio = self.buckets[bucket] |
324 | width = self.size * ratio if ratio > 1 else self.size | 326 | width = int(self.size * ratio) if ratio > 1 else self.size |
325 | height = self.size / ratio if ratio < 1 else self.size | 327 | height = int(self.size / ratio) if ratio < 1 else self.size |
326 | 328 | ||
327 | image_transforms = transforms.Compose( | 329 | image_transforms = transforms.Compose( |
328 | [ | 330 | [ |