summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-08 15:21:37 +0100
committerVolpeon <git@volpeon.ink>2023-01-08 15:21:37 +0100
commit8373527ceed371338803a54721a48212bcad565f (patch)
tree501ba1f5dc7ef98f104ab52565c9164fd3d1bf32 /data
parentCleanup (diff)
downloadtextual-inversion-diff-8373527ceed371338803a54721a48212bcad565f.tar.gz
textual-inversion-diff-8373527ceed371338803a54721a48212bcad565f.tar.bz2
textual-inversion-diff-8373527ceed371338803a54721a48212bcad565f.zip
Fixed aspect ratio bucketing
Diffstat (limited to 'data')
-rw-r--r--data/csv.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index eaef5e6..7527b7d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -70,7 +70,9 @@ def generate_buckets(
70 mask = torch.logical_and(buckets <= 1, buckets >= ratio) 70 mask = torch.logical_and(buckets <= 1, buckets >= ratio)
71 71
72 if not progressive_buckets: 72 if not progressive_buckets:
73 mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() 73 inf = torch.zeros_like(buckets)
74 inf[~mask] = math.inf
75 mask = (buckets + inf - ratio).abs().argmin()
74 76
75 indices = bucket_indices[mask] 77 indices = bucket_indices[mask]
76 78
@@ -321,8 +323,8 @@ class VlpnDataset(IterableDataset):
321 323
322 bucket = self.bucket_assignments[mask][0] 324 bucket = self.bucket_assignments[mask][0]
323 ratio = self.buckets[bucket] 325 ratio = self.buckets[bucket]
324 width = self.size * ratio if ratio > 1 else self.size 326 width = int(self.size * ratio) if ratio > 1 else self.size
325 height = self.size / ratio if ratio < 1 else self.size 327 height = int(self.size / ratio) if ratio < 1 else self.size
326 328
327 image_transforms = transforms.Compose( 329 image_transforms = transforms.Compose(
328 [ 330 [