diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 7527b7d..55a1988 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -44,18 +44,25 @@ def generate_buckets( | |||
44 | items: list[str], | 44 | items: list[str], |
45 | base_size: int, | 45 | base_size: int, |
46 | step_size: int = 64, | 46 | step_size: int = 64, |
47 | max_pixels: Optional[int] = None, | ||
47 | num_buckets: int = 4, | 48 | num_buckets: int = 4, |
48 | progressive_buckets: bool = False, | 49 | progressive_buckets: bool = False, |
49 | return_tensor: bool = True | 50 | return_tensor: bool = True |
50 | ): | 51 | ): |
52 | if max_pixels is None: | ||
53 | max_pixels = (base_size + step_size) ** 2 | ||
54 | |||
55 | max_pixels = max(max_pixels, base_size * base_size) | ||
56 | |||
51 | bucket_items: list[int] = [] | 57 | bucket_items: list[int] = [] |
52 | bucket_assignments: list[int] = [] | 58 | bucket_assignments: list[int] = [] |
53 | buckets = [1.0] | 59 | buckets = [1.0] |
54 | 60 | ||
55 | for i in range(1, num_buckets + 1): | 61 | for i in range(1, num_buckets + 1): |
56 | s = base_size + i * step_size | 62 | long_side = base_size + i * step_size |
57 | buckets.append(s / base_size) | 63 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) |
58 | buckets.append(base_size / s) | 64 | buckets.append(long_side / short_side) |
65 | buckets.append(short_side / long_side) | ||
59 | 66 | ||
60 | buckets = torch.tensor(buckets) | 67 | buckets = torch.tensor(buckets) |
61 | bucket_indices = torch.arange(len(buckets)) | 68 | bucket_indices = torch.arange(len(buckets)) |
@@ -110,6 +117,8 @@ class VlpnDataModule(): | |||
110 | num_class_images: int = 1, | 117 | num_class_images: int = 1, |
111 | size: int = 768, | 118 | size: int = 768, |
112 | num_buckets: int = 0, | 119 | num_buckets: int = 0, |
120 | bucket_step_size: int = 64, | ||
121 | max_pixels_per_bucket: Optional[int] = None, | ||
113 | progressive_buckets: bool = False, | 122 | progressive_buckets: bool = False, |
114 | dropout: float = 0, | 123 | dropout: float = 0, |
115 | interpolation: str = "bicubic", | 124 | interpolation: str = "bicubic", |
@@ -135,6 +144,8 @@ class VlpnDataModule(): | |||
135 | self.prompt_processor = prompt_processor | 144 | self.prompt_processor = prompt_processor |
136 | self.size = size | 145 | self.size = size |
137 | self.num_buckets = num_buckets | 146 | self.num_buckets = num_buckets |
147 | self.bucket_step_size = bucket_step_size | ||
148 | self.max_pixels_per_bucket = max_pixels_per_bucket | ||
138 | self.progressive_buckets = progressive_buckets | 149 | self.progressive_buckets = progressive_buckets |
139 | self.dropout = dropout | 150 | self.dropout = dropout |
140 | self.template_key = template_key | 151 | self.template_key = template_key |
@@ -223,6 +234,7 @@ class VlpnDataModule(): | |||
223 | train_dataset = VlpnDataset( | 234 | train_dataset = VlpnDataset( |
224 | self.data_train, self.prompt_processor, | 235 | self.data_train, self.prompt_processor, |
225 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 236 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
237 | bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket, | ||
226 | batch_size=self.batch_size, generator=generator, | 238 | batch_size=self.batch_size, generator=generator, |
227 | size=self.size, interpolation=self.interpolation, | 239 | size=self.size, interpolation=self.interpolation, |
228 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 240 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
@@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset): | |||
251 | items: list[VlpnDataItem], | 263 | items: list[VlpnDataItem], |
252 | prompt_processor: PromptProcessor, | 264 | prompt_processor: PromptProcessor, |
253 | num_buckets: int = 1, | 265 | num_buckets: int = 1, |
266 | bucket_step_size: int = 64, | ||
267 | max_pixels_per_bucket: Optional[int] = None, | ||
254 | progressive_buckets: bool = False, | 268 | progressive_buckets: bool = False, |
255 | batch_size: int = 1, | 269 | batch_size: int = 1, |
256 | num_class_images: int = 0, | 270 | num_class_images: int = 0, |
@@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset): | |||
274 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 288 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
275 | [item.instance_image_path for item in items], | 289 | [item.instance_image_path for item in items], |
276 | base_size=size, | 290 | base_size=size, |
291 | step_size=bucket_step_size, | ||
277 | num_buckets=num_buckets, | 292 | num_buckets=num_buckets, |
293 | max_pixels=max_pixels_per_bucket, | ||
278 | progressive_buckets=progressive_buckets, | 294 | progressive_buckets=progressive_buckets, |
279 | ) | 295 | ) |
280 | 296 | ||