From ecb12378da48fc3a17539d5cc33edc561cf8a426 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 8 Jan 2023 20:33:04 +0100
Subject: Improved aspect ratio bucketing

---
 data/csv.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 7527b7d..55a1988 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -44,18 +44,25 @@ def generate_buckets(
     items: list[str],
     base_size: int,
     step_size: int = 64,
+    max_pixels: Optional[int] = None,
     num_buckets: int = 4,
     progressive_buckets: bool = False,
     return_tensor: bool = True
 ):
+    if max_pixels is None:
+        max_pixels = (base_size + step_size) ** 2
+
+    max_pixels = max(max_pixels, base_size * base_size)
+
     bucket_items: list[int] = []
     bucket_assignments: list[int] = []
     buckets = [1.0]
 
     for i in range(1, num_buckets + 1):
-        s = base_size + i * step_size
-        buckets.append(s / base_size)
-        buckets.append(base_size / s)
+        long_side = base_size + i * step_size
+        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
+        buckets.append(long_side / short_side)
+        buckets.append(short_side / long_side)
 
     buckets = torch.tensor(buckets)
     bucket_indices = torch.arange(len(buckets))
@@ -110,6 +117,8 @@ class VlpnDataModule():
         num_class_images: int = 1,
         size: int = 768,
         num_buckets: int = 0,
+        bucket_step_size: int = 64,
+        max_pixels_per_bucket: Optional[int] = None,
         progressive_buckets: bool = False,
         dropout: float = 0,
         interpolation: str = "bicubic",
@@ -135,6 +144,8 @@ class VlpnDataModule():
         self.prompt_processor = prompt_processor
         self.size = size
         self.num_buckets = num_buckets
+        self.bucket_step_size = bucket_step_size
+        self.max_pixels_per_bucket = max_pixels_per_bucket
         self.progressive_buckets = progressive_buckets
         self.dropout = dropout
         self.template_key = template_key
@@ -223,6 +234,7 @@ class VlpnDataModule():
         train_dataset = VlpnDataset(
             self.data_train, self.prompt_processor,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
+            bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket,
             batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
@@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset):
         items: list[VlpnDataItem],
         prompt_processor: PromptProcessor,
         num_buckets: int = 1,
+        bucket_step_size: int = 64,
+        max_pixels_per_bucket: Optional[int] = None,
         progressive_buckets: bool = False,
         batch_size: int = 1,
         num_class_images: int = 0,
@@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset):
         self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
             [item.instance_image_path for item in items],
             base_size=size,
+            step_size=bucket_step_size,
             num_buckets=num_buckets,
+            max_pixels=max_pixels_per_bucket,
             progressive_buckets=progressive_buckets,
        )
-- 
cgit v1.2.3-54-g00ecf
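
Note (not part of the patch): the sketch below is a minimal, standalone illustration of the new bucket sizing rule, assuming the defaults shown above (base_size=768, step_size=64, num_buckets=4). The helper name bucket_resolutions is hypothetical and does not exist in data/csv.py; it only mirrors the aspect-ratio loop from generate_buckets() so the effect of max_pixels can be inspected in isolation: as the long side grows by one step per bucket, the short side is reduced in whole steps until the bucket fits the pixel budget.

import math
from typing import Optional


def bucket_resolutions(
    base_size: int,
    step_size: int = 64,
    max_pixels: Optional[int] = None,
    num_buckets: int = 4,
) -> list[tuple[int, int]]:
    # Illustrative helper (not in data/csv.py); mirrors the patched loop above.
    # Default pixel budget, as in the patch: one extra step on each side, squared.
    if max_pixels is None:
        max_pixels = (base_size + step_size) ** 2
    max_pixels = max(max_pixels, base_size * base_size)

    sizes = [(base_size, base_size)]  # the square 1:1 bucket
    for i in range(1, num_buckets + 1):
        long_side = base_size + i * step_size
        # Shrink the short side in whole steps until long_side * short_side fits the budget.
        short_side = min(
            base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
            base_size,
        )
        sizes.append((long_side, short_side))  # landscape bucket
        sizes.append((short_side, long_side))  # portrait bucket
    return sizes


if __name__ == "__main__":
    for w, h in bucket_resolutions(768):
        print(f"{w}x{h} = {w * h} px")

With the defaults this prints 768x768, then 832x768, 896x768, 960x704 and 1024x640 (plus their portrait counterparts), so every bucket stays at or below the (768 + 64)^2 pixel budget. The old ratios (s / base_size with s up to base_size + num_buckets * step_size) correspond to shapes as large as 1024x768 when the short side is kept at base_size, so the new rule trades aspect-ratio coverage for a bounded memory footprint per batch.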