From ecb12378da48fc3a17539d5cc33edc561cf8a426 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 20:33:04 +0100 Subject: Improved aspect ratio bucketing --- data/csv.py | 22 +++++++++++++++++++--- train_dreambooth.py | 27 +++++++++++++++++++++++++++ train_ti.py | 16 +++++++++++++++- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/data/csv.py b/data/csv.py index 7527b7d..55a1988 100644 --- a/data/csv.py +++ b/data/csv.py @@ -44,18 +44,25 @@ def generate_buckets( items: list[str], base_size: int, step_size: int = 64, + max_pixels: Optional[int] = None, num_buckets: int = 4, progressive_buckets: bool = False, return_tensor: bool = True ): + if max_pixels is None: + max_pixels = (base_size + step_size) ** 2 + + max_pixels = max(max_pixels, base_size * base_size) + bucket_items: list[int] = [] bucket_assignments: list[int] = [] buckets = [1.0] for i in range(1, num_buckets + 1): - s = base_size + i * step_size - buckets.append(s / base_size) - buckets.append(base_size / s) + long_side = base_size + i * step_size + short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) + buckets.append(long_side / short_side) + buckets.append(short_side / long_side) buckets = torch.tensor(buckets) bucket_indices = torch.arange(len(buckets)) @@ -110,6 +117,8 @@ class VlpnDataModule(): num_class_images: int = 1, size: int = 768, num_buckets: int = 0, + bucket_step_size: int = 64, + max_pixels_per_bucket: Optional[int] = None, progressive_buckets: bool = False, dropout: float = 0, interpolation: str = "bicubic", @@ -135,6 +144,8 @@ class VlpnDataModule(): self.prompt_processor = prompt_processor self.size = size self.num_buckets = num_buckets + self.bucket_step_size = bucket_step_size + self.max_pixels_per_bucket = max_pixels_per_bucket self.progressive_buckets = progressive_buckets self.dropout = dropout self.template_key = template_key @@ -223,6 +234,7 @@ class VlpnDataModule(): train_dataset = VlpnDataset( self.data_train, self.prompt_processor, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, + bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, @@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset): items: list[VlpnDataItem], prompt_processor: PromptProcessor, num_buckets: int = 1, + bucket_step_size: int = 64, + max_pixels_per_bucket: Optional[int] = None, progressive_buckets: bool = False, batch_size: int = 1, num_class_images: int = 0, @@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset): self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( [item.instance_image_path for item in items], base_size=size, + step_size=bucket_step_size, num_buckets=num_buckets, + max_pixels=max_pixels_per_bucket, progressive_buckets=progressive_buckets, ) diff --git a/train_dreambooth.py b/train_dreambooth.py index 79eede6..d396249 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -103,6 +103,29 @@ def parse_args(): default=999999, help="Number of epochs the text encoder will be trained." ) + parser.add_argument( + "--num_buckets", + type=int, + default=4, + help="Number of aspect ratio buckets in either direction.", + ) + parser.add_argument( + "--progressive_buckets", + action="store_true", + help="Include images in smaller buckets as well.", + ) + parser.add_argument( + "--bucket_step_size", + type=int, + default=64, + help="Step size between buckets.", + ) + parser.add_argument( + "--bucket_max_pixels", + type=int, + default=None, + help="Maximum pixels per bucket.", + ) parser.add_argument( "--tag_dropout", type=float, @@ -734,6 +757,10 @@ def main(): class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, + num_buckets=args.num_buckets, + progressive_buckets=args.progressive_buckets, + bucket_step_size=args.bucket_step_size, + bucket_max_pixels=args.bucket_max_pixels, dropout=args.tag_dropout, template_key=args.train_data_template, valid_set_size=args.valid_set_size, diff --git a/train_ti.py b/train_ti.py index 323ef10..eb0b8b6 100644 --- a/train_ti.py +++ b/train_ti.py @@ -143,13 +143,25 @@ def parse_args(): "--num_buckets", type=int, default=4, - help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", + help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( "--progressive_buckets", action="store_true", help="Include images in smaller buckets as well.", ) + parser.add_argument( + "--bucket_step_size", + type=int, + default=64, + help="Step size between buckets.", + ) + parser.add_argument( + "--bucket_max_pixels", + type=int, + default=None, + help="Maximum pixels per bucket.", + ) parser.add_argument( "--tag_dropout", type=float, @@ -718,6 +730,8 @@ def main(): size=args.resolution, num_buckets=args.num_buckets, progressive_buckets=args.progressive_buckets, + bucket_step_size=args.bucket_step_size, + bucket_max_pixels=args.bucket_max_pixels, dropout=args.tag_dropout, template_key=args.train_data_template, valid_set_size=args.valid_set_size, -- cgit v1.2.3-54-g00ecf