summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py40
1 files changed, 25 insertions, 15 deletions
diff --git a/data/csv.py b/data/csv.py
index 289a64d..eaef5e6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -40,15 +40,22 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
40 return {"content": prompt} if isinstance(prompt, str) else prompt 40 return {"content": prompt} if isinstance(prompt, str) else prompt
41 41
42 42
43def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): 43def generate_buckets(
44 items: list[str],
45 base_size: int,
46 step_size: int = 64,
47 num_buckets: int = 4,
48 progressive_buckets: bool = False,
49 return_tensor: bool = True
50):
44 bucket_items: list[int] = [] 51 bucket_items: list[int] = []
45 bucket_assignments: list[int] = [] 52 bucket_assignments: list[int] = []
46 buckets = [1.0] 53 buckets = [1.0]
47 54
48 for i in range(1, num_buckets + 1): 55 for i in range(1, num_buckets + 1):
49 s = size + i * 64 56 s = base_size + i * step_size
50 buckets.append(s / size) 57 buckets.append(s / base_size)
51 buckets.append(size / s) 58 buckets.append(base_size / s)
52 59
53 buckets = torch.tensor(buckets) 60 buckets = torch.tensor(buckets)
54 bucket_indices = torch.arange(len(buckets)) 61 bucket_indices = torch.arange(len(buckets))
@@ -58,9 +65,9 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
58 ratio = image.width / image.height 65 ratio = image.width / image.height
59 66
60 if ratio >= 1: 67 if ratio >= 1:
61 mask = torch.bitwise_and(buckets >= 1, buckets <= ratio) 68 mask = torch.logical_and(buckets >= 1, buckets <= ratio)
62 else: 69 else:
63 mask = torch.bitwise_and(buckets <= 1, buckets >= ratio) 70 mask = torch.logical_and(buckets <= 1, buckets >= ratio)
64 71
65 if not progressive_buckets: 72 if not progressive_buckets:
66 mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() 73 mask = (buckets + (~mask) * math.inf - ratio).abs().argmin()
@@ -73,7 +80,13 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
73 bucket_items += [i] * len(indices) 80 bucket_items += [i] * len(indices)
74 bucket_assignments += indices 81 bucket_assignments += indices
75 82
76 return buckets.tolist(), bucket_items, bucket_assignments 83 if return_tensor:
84 bucket_items = torch.tensor(bucket_items)
85 bucket_assignments = torch.tensor(bucket_assignments)
86 else:
87 buckets = buckets.tolist()
88
89 return buckets, bucket_items, bucket_assignments
77 90
78 91
79class VlpnDataItem(NamedTuple): 92class VlpnDataItem(NamedTuple):
@@ -256,17 +269,14 @@ class VlpnDataset(IterableDataset):
256 self.interpolation = interpolations[interpolation] 269 self.interpolation = interpolations[interpolation]
257 self.generator = generator 270 self.generator = generator
258 271
259 buckets, bucket_items, bucket_assignments = generate_buckets( 272 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
260 [item.instance_image_path for item in items], 273 [item.instance_image_path for item in items],
261 size, 274 base_size=size,
262 num_buckets, 275 num_buckets=num_buckets,
263 progressive_buckets 276 progressive_buckets=progressive_buckets,
264 ) 277 )
265 278
266 self.buckets = torch.tensor(buckets) 279 self.bucket_item_range = torch.arange(len(self.bucket_items))
267 self.bucket_items = torch.tensor(bucket_items)
268 self.bucket_assignments = torch.tensor(bucket_assignments)
269 self.bucket_item_range = torch.arange(len(bucket_items))
270 280
271 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() 281 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
272 282