diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 78 |
1 files changed, 43 insertions, 35 deletions
diff --git a/data/csv.py b/data/csv.py index 9be36ba..289a64d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): | |||
41 | 41 | ||
42 | 42 | ||
43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): | 43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): |
44 | item_order: list[int] = [] | 44 | bucket_items: list[int] = [] |
45 | item_buckets: list[int] = [] | 45 | bucket_assignments: list[int] = [] |
46 | buckets = [1.0] | 46 | buckets = [1.0] |
47 | 47 | ||
48 | for i in range(1, num_buckets + 1): | 48 | for i in range(1, num_buckets + 1): |
@@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ | |||
70 | if len(indices.shape) == 0: | 70 | if len(indices.shape) == 0: |
71 | indices = indices.unsqueeze(0) | 71 | indices = indices.unsqueeze(0) |
72 | 72 | ||
73 | item_order += [i] * len(indices) | 73 | bucket_items += [i] * len(indices) |
74 | item_buckets += indices | 74 | bucket_assignments += indices |
75 | 75 | ||
76 | return buckets.tolist(), item_order, item_buckets | 76 | return buckets.tolist(), bucket_items, bucket_assignments |
77 | 77 | ||
78 | 78 | ||
79 | class VlpnDataItem(NamedTuple): | 79 | class VlpnDataItem(NamedTuple): |
@@ -94,8 +94,8 @@ class VlpnDataModule(): | |||
94 | class_subdir: str = "cls", | 94 | class_subdir: str = "cls", |
95 | num_class_images: int = 1, | 95 | num_class_images: int = 1, |
96 | size: int = 768, | 96 | size: int = 768, |
97 | num_aspect_ratio_buckets: int = 0, | 97 | num_buckets: int = 0, |
98 | progressive_aspect_ratio_buckets: bool = False, | 98 | progressive_buckets: bool = False, |
99 | dropout: float = 0, | 99 | dropout: float = 0, |
100 | interpolation: str = "bicubic", | 100 | interpolation: str = "bicubic", |
101 | template_key: str = "template", | 101 | template_key: str = "template", |
@@ -119,8 +119,8 @@ class VlpnDataModule(): | |||
119 | 119 | ||
120 | self.prompt_processor = prompt_processor | 120 | self.prompt_processor = prompt_processor |
121 | self.size = size | 121 | self.size = size |
122 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets | 122 | self.num_buckets = num_buckets |
123 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets | 123 | self.progressive_buckets = progressive_buckets |
124 | self.dropout = dropout | 124 | self.dropout = dropout |
125 | self.template_key = template_key | 125 | self.template_key = template_key |
126 | self.interpolation = interpolation | 126 | self.interpolation = interpolation |
@@ -207,15 +207,15 @@ class VlpnDataModule(): | |||
207 | 207 | ||
208 | train_dataset = VlpnDataset( | 208 | train_dataset = VlpnDataset( |
209 | self.data_train, self.prompt_processor, | 209 | self.data_train, self.prompt_processor, |
210 | num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, | 210 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
211 | batch_size=self.batch_size, | 211 | batch_size=self.batch_size, generator=generator, |
212 | size=self.size, interpolation=self.interpolation, | 212 | size=self.size, interpolation=self.interpolation, |
213 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 213 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
214 | ) | 214 | ) |
215 | 215 | ||
216 | val_dataset = VlpnDataset( | 216 | val_dataset = VlpnDataset( |
217 | self.data_val, self.prompt_processor, | 217 | self.data_val, self.prompt_processor, |
218 | batch_size=self.batch_size, | 218 | batch_size=self.batch_size, generator=generator, |
219 | size=self.size, interpolation=self.interpolation, | 219 | size=self.size, interpolation=self.interpolation, |
220 | ) | 220 | ) |
221 | 221 | ||
@@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset): | |||
256 | self.interpolation = interpolations[interpolation] | 256 | self.interpolation = interpolations[interpolation] |
257 | self.generator = generator | 257 | self.generator = generator |
258 | 258 | ||
259 | buckets, item_order, item_buckets = generate_buckets( | 259 | buckets, bucket_items, bucket_assignments = generate_buckets( |
260 | [item.instance_image_path for item in items], | 260 | [item.instance_image_path for item in items], |
261 | size, | 261 | size, |
262 | num_buckets, | 262 | num_buckets, |
@@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset): | |||
264 | ) | 264 | ) |
265 | 265 | ||
266 | self.buckets = torch.tensor(buckets) | 266 | self.buckets = torch.tensor(buckets) |
267 | self.item_order = torch.tensor(item_order) | 267 | self.bucket_items = torch.tensor(bucket_items) |
268 | self.item_buckets = torch.tensor(item_buckets) | 268 | self.bucket_assignments = torch.tensor(bucket_assignments) |
269 | self.bucket_item_range = torch.arange(len(bucket_items)) | ||
270 | |||
271 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | ||
269 | 272 | ||
270 | def __len__(self): | 273 | def __len__(self): |
271 | return len(self.item_buckets) | 274 | return self.length_ |
272 | 275 | ||
273 | def __iter__(self): | 276 | def __iter__(self): |
274 | worker_info = torch.utils.data.get_worker_info() | 277 | worker_info = torch.utils.data.get_worker_info() |
275 | 278 | ||
276 | if self.shuffle: | 279 | if self.shuffle: |
277 | perm = torch.randperm(len(self.item_buckets), generator=self.generator) | 280 | perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) |
278 | self.item_order = self.item_order[perm] | 281 | self.bucket_items = self.bucket_items[perm] |
279 | self.item_buckets = self.item_buckets[perm] | 282 | self.bucket_assignments = self.bucket_assignments[perm] |
280 | 283 | ||
281 | item_mask = torch.ones_like(self.item_buckets, dtype=bool) | ||
282 | bucket = -1 | ||
283 | image_transforms = None | 284 | image_transforms = None |
285 | |||
286 | mask = torch.ones_like(self.bucket_assignments, dtype=bool) | ||
287 | bucket = -1 | ||
284 | batch = [] | 288 | batch = [] |
285 | batch_size = self.batch_size | 289 | batch_size = self.batch_size |
286 | 290 | ||
@@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset): | |||
289 | worker_batch = math.ceil(len(self) / worker_info.num_workers) | 293 | worker_batch = math.ceil(len(self) / worker_info.num_workers) |
290 | start = worker_info.id * worker_batch | 294 | start = worker_info.id * worker_batch |
291 | end = start + worker_batch | 295 | end = start + worker_batch |
292 | item_mask[:start] = False | 296 | mask[:start] = False |
293 | item_mask[end:] = False | 297 | mask[end:] = False |
294 | 298 | ||
295 | while item_mask.any(): | 299 | while mask.any(): |
296 | item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] | 300 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) |
301 | bucket_items = self.bucket_items[bucket_mask] | ||
297 | 302 | ||
298 | if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): | 303 | if len(batch) >= batch_size: |
299 | yield batch | 304 | yield batch |
300 | batch = [] | 305 | batch = [] |
301 | 306 | ||
302 | if len(item_indices) == 0: | 307 | if len(bucket_items) == 0: |
303 | bucket = self.item_buckets[item_mask][0] | 308 | if len(batch) != 0: |
309 | yield batch | ||
310 | batch = [] | ||
311 | |||
312 | bucket = self.bucket_assignments[mask][0] | ||
304 | ratio = self.buckets[bucket] | 313 | ratio = self.buckets[bucket] |
305 | width = self.size * ratio if ratio > 1 else self.size | 314 | width = self.size * ratio if ratio > 1 else self.size |
306 | height = self.size / ratio if ratio < 1 else self.size | 315 | height = self.size / ratio if ratio < 1 else self.size |
307 | 316 | ||
308 | image_transforms = transforms.Compose( | 317 | image_transforms = transforms.Compose( |
309 | [ | 318 | [ |
310 | transforms.Resize(min(width, height), interpolation=self.interpolation), | 319 | transforms.Resize(self.size, interpolation=self.interpolation), |
311 | transforms.RandomCrop((height, width)), | 320 | transforms.RandomCrop((height, width)), |
312 | transforms.RandomHorizontalFlip(), | 321 | transforms.RandomHorizontalFlip(), |
313 | transforms.ToTensor(), | 322 | transforms.ToTensor(), |
@@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset): | |||
315 | ] | 324 | ] |
316 | ) | 325 | ) |
317 | else: | 326 | else: |
318 | item_index = item_indices[0] | 327 | item_index = bucket_items[0] |
319 | item = self.items[item_index] | 328 | item = self.items[item_index] |
320 | item_mask[item_index] = False | 329 | mask[self.bucket_item_range[bucket_mask][0]] = False |
321 | 330 | ||
322 | example = {} | 331 | example = {} |
323 | 332 | ||
324 | example["prompts"] = keywords_to_prompt(item.prompt) | 333 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) |
325 | example["cprompts"] = item.cprompt | 334 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) |
326 | example["nprompts"] = item.nprompt | ||
327 | 335 | ||
328 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 336 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
329 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 337 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
@@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset): | |||
332 | 340 | ||
333 | if self.num_class_images != 0: | 341 | if self.num_class_images != 0: |
334 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 342 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
335 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) | 343 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) |
336 | 344 | ||
337 | batch.append(example) | 345 | batch.append(example) |
338 | 346 | ||