diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 22 |
1 file changed, 19 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 7527b7d..55a1988 100644
--- a/data/csv.py
+++ b/data/csv.py
| @@ -44,18 +44,25 @@ def generate_buckets( | |||
| 44 | items: list[str], | 44 | items: list[str], |
| 45 | base_size: int, | 45 | base_size: int, |
| 46 | step_size: int = 64, | 46 | step_size: int = 64, |
| 47 | max_pixels: Optional[int] = None, | ||
| 47 | num_buckets: int = 4, | 48 | num_buckets: int = 4, |
| 48 | progressive_buckets: bool = False, | 49 | progressive_buckets: bool = False, |
| 49 | return_tensor: bool = True | 50 | return_tensor: bool = True |
| 50 | ): | 51 | ): |
| 52 | if max_pixels is None: | ||
| 53 | max_pixels = (base_size + step_size) ** 2 | ||
| 54 | |||
| 55 | max_pixels = max(max_pixels, base_size * base_size) | ||
| 56 | |||
| 51 | bucket_items: list[int] = [] | 57 | bucket_items: list[int] = [] |
| 52 | bucket_assignments: list[int] = [] | 58 | bucket_assignments: list[int] = [] |
| 53 | buckets = [1.0] | 59 | buckets = [1.0] |
| 54 | 60 | ||
| 55 | for i in range(1, num_buckets + 1): | 61 | for i in range(1, num_buckets + 1): |
| 56 | s = base_size + i * step_size | 62 | long_side = base_size + i * step_size |
| 57 | buckets.append(s / base_size) | 63 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) |
| 58 | buckets.append(base_size / s) | 64 | buckets.append(long_side / short_side) |
| 65 | buckets.append(short_side / long_side) | ||
| 59 | 66 | ||
| 60 | buckets = torch.tensor(buckets) | 67 | buckets = torch.tensor(buckets) |
| 61 | bucket_indices = torch.arange(len(buckets)) | 68 | bucket_indices = torch.arange(len(buckets)) |
| @@ -110,6 +117,8 @@ class VlpnDataModule(): | |||
| 110 | num_class_images: int = 1, | 117 | num_class_images: int = 1, |
| 111 | size: int = 768, | 118 | size: int = 768, |
| 112 | num_buckets: int = 0, | 119 | num_buckets: int = 0, |
| 120 | bucket_step_size: int = 64, | ||
| 121 | max_pixels_per_bucket: Optional[int] = None, | ||
| 113 | progressive_buckets: bool = False, | 122 | progressive_buckets: bool = False, |
| 114 | dropout: float = 0, | 123 | dropout: float = 0, |
| 115 | interpolation: str = "bicubic", | 124 | interpolation: str = "bicubic", |
| @@ -135,6 +144,8 @@ class VlpnDataModule(): | |||
| 135 | self.prompt_processor = prompt_processor | 144 | self.prompt_processor = prompt_processor |
| 136 | self.size = size | 145 | self.size = size |
| 137 | self.num_buckets = num_buckets | 146 | self.num_buckets = num_buckets |
| 147 | self.bucket_step_size = bucket_step_size | ||
| 148 | self.max_pixels_per_bucket = max_pixels_per_bucket | ||
| 138 | self.progressive_buckets = progressive_buckets | 149 | self.progressive_buckets = progressive_buckets |
| 139 | self.dropout = dropout | 150 | self.dropout = dropout |
| 140 | self.template_key = template_key | 151 | self.template_key = template_key |
| @@ -223,6 +234,7 @@ class VlpnDataModule(): | |||
| 223 | train_dataset = VlpnDataset( | 234 | train_dataset = VlpnDataset( |
| 224 | self.data_train, self.prompt_processor, | 235 | self.data_train, self.prompt_processor, |
| 225 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 236 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
| 237 | bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket, | ||
| 226 | batch_size=self.batch_size, generator=generator, | 238 | batch_size=self.batch_size, generator=generator, |
| 227 | size=self.size, interpolation=self.interpolation, | 239 | size=self.size, interpolation=self.interpolation, |
| 228 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 240 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
| @@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset): | |||
| 251 | items: list[VlpnDataItem], | 263 | items: list[VlpnDataItem], |
| 252 | prompt_processor: PromptProcessor, | 264 | prompt_processor: PromptProcessor, |
| 253 | num_buckets: int = 1, | 265 | num_buckets: int = 1, |
| 266 | bucket_step_size: int = 64, | ||
| 267 | max_pixels_per_bucket: Optional[int] = None, | ||
| 254 | progressive_buckets: bool = False, | 268 | progressive_buckets: bool = False, |
| 255 | batch_size: int = 1, | 269 | batch_size: int = 1, |
| 256 | num_class_images: int = 0, | 270 | num_class_images: int = 0, |
| @@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset): | |||
| 274 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 288 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
| 275 | [item.instance_image_path for item in items], | 289 | [item.instance_image_path for item in items], |
| 276 | base_size=size, | 290 | base_size=size, |
| 291 | step_size=bucket_step_size, | ||
| 277 | num_buckets=num_buckets, | 292 | num_buckets=num_buckets, |
| 293 | max_pixels=max_pixels_per_bucket, | ||
| 278 | progressive_buckets=progressive_buckets, | 294 | progressive_buckets=progressive_buckets, |
| 279 | ) | 295 | ) |
| 280 | 296 | ||
