From c16d5ba5ab1445386efcb87325557d4c784891e6 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 14:42:40 +0100 Subject: Cleanup --- data/csv.py | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 289a64d..eaef5e6 100644 --- a/data/csv.py +++ b/data/csv.py @@ -40,15 +40,22 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt -def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): +def generate_buckets( + items: list[str], + base_size: int, + step_size: int = 64, + num_buckets: int = 4, + progressive_buckets: bool = False, + return_tensor: bool = True +): bucket_items: list[int] = [] bucket_assignments: list[int] = [] buckets = [1.0] for i in range(1, num_buckets + 1): - s = size + i * 64 - buckets.append(s / size) - buckets.append(size / s) + s = base_size + i * step_size + buckets.append(s / base_size) + buckets.append(base_size / s) buckets = torch.tensor(buckets) bucket_indices = torch.arange(len(buckets)) @@ -58,9 +65,9 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ ratio = image.width / image.height if ratio >= 1: - mask = torch.bitwise_and(buckets >= 1, buckets <= ratio) + mask = torch.logical_and(buckets >= 1, buckets <= ratio) else: - mask = torch.bitwise_and(buckets <= 1, buckets >= ratio) + mask = torch.logical_and(buckets <= 1, buckets >= ratio) if not progressive_buckets: mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() @@ -73,7 +80,13 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ bucket_items += [i] * len(indices) bucket_assignments += indices - return buckets.tolist(), bucket_items, bucket_assignments + if return_tensor: + bucket_items = torch.tensor(bucket_items) + bucket_assignments = torch.tensor(bucket_assignments) + else: + buckets = buckets.tolist() + + return buckets, bucket_items, bucket_assignments class VlpnDataItem(NamedTuple): @@ -256,17 +269,14 @@ class VlpnDataset(IterableDataset): self.interpolation = interpolations[interpolation] self.generator = generator - buckets, bucket_items, bucket_assignments = generate_buckets( + self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( [item.instance_image_path for item in items], - size, - num_buckets, - progressive_buckets + base_size=size, + num_buckets=num_buckets, + progressive_buckets=progressive_buckets, ) - self.buckets = torch.tensor(buckets) - self.bucket_items = torch.tensor(bucket_items) - self.bucket_assignments = torch.tensor(bucket_assignments) - self.bucket_item_range = torch.arange(len(bucket_items)) + self.bucket_item_range = torch.arange(len(self.bucket_items)) self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() -- cgit v1.2.3-70-g09d2