diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 78 |
1 files changed, 43 insertions, 35 deletions
diff --git a/data/csv.py b/data/csv.py index 9be36ba..289a64d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): | |||
| 41 | 41 | ||
| 42 | 42 | ||
| 43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): | 43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): |
| 44 | item_order: list[int] = [] | 44 | bucket_items: list[int] = [] |
| 45 | item_buckets: list[int] = [] | 45 | bucket_assignments: list[int] = [] |
| 46 | buckets = [1.0] | 46 | buckets = [1.0] |
| 47 | 47 | ||
| 48 | for i in range(1, num_buckets + 1): | 48 | for i in range(1, num_buckets + 1): |
| @@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ | |||
| 70 | if len(indices.shape) == 0: | 70 | if len(indices.shape) == 0: |
| 71 | indices = indices.unsqueeze(0) | 71 | indices = indices.unsqueeze(0) |
| 72 | 72 | ||
| 73 | item_order += [i] * len(indices) | 73 | bucket_items += [i] * len(indices) |
| 74 | item_buckets += indices | 74 | bucket_assignments += indices |
| 75 | 75 | ||
| 76 | return buckets.tolist(), item_order, item_buckets | 76 | return buckets.tolist(), bucket_items, bucket_assignments |
| 77 | 77 | ||
| 78 | 78 | ||
| 79 | class VlpnDataItem(NamedTuple): | 79 | class VlpnDataItem(NamedTuple): |
| @@ -94,8 +94,8 @@ class VlpnDataModule(): | |||
| 94 | class_subdir: str = "cls", | 94 | class_subdir: str = "cls", |
| 95 | num_class_images: int = 1, | 95 | num_class_images: int = 1, |
| 96 | size: int = 768, | 96 | size: int = 768, |
| 97 | num_aspect_ratio_buckets: int = 0, | 97 | num_buckets: int = 0, |
| 98 | progressive_aspect_ratio_buckets: bool = False, | 98 | progressive_buckets: bool = False, |
| 99 | dropout: float = 0, | 99 | dropout: float = 0, |
| 100 | interpolation: str = "bicubic", | 100 | interpolation: str = "bicubic", |
| 101 | template_key: str = "template", | 101 | template_key: str = "template", |
| @@ -119,8 +119,8 @@ class VlpnDataModule(): | |||
| 119 | 119 | ||
| 120 | self.prompt_processor = prompt_processor | 120 | self.prompt_processor = prompt_processor |
| 121 | self.size = size | 121 | self.size = size |
| 122 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets | 122 | self.num_buckets = num_buckets |
| 123 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets | 123 | self.progressive_buckets = progressive_buckets |
| 124 | self.dropout = dropout | 124 | self.dropout = dropout |
| 125 | self.template_key = template_key | 125 | self.template_key = template_key |
| 126 | self.interpolation = interpolation | 126 | self.interpolation = interpolation |
| @@ -207,15 +207,15 @@ class VlpnDataModule(): | |||
| 207 | 207 | ||
| 208 | train_dataset = VlpnDataset( | 208 | train_dataset = VlpnDataset( |
| 209 | self.data_train, self.prompt_processor, | 209 | self.data_train, self.prompt_processor, |
| 210 | num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, | 210 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
| 211 | batch_size=self.batch_size, | 211 | batch_size=self.batch_size, generator=generator, |
| 212 | size=self.size, interpolation=self.interpolation, | 212 | size=self.size, interpolation=self.interpolation, |
| 213 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 213 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
| 214 | ) | 214 | ) |
| 215 | 215 | ||
| 216 | val_dataset = VlpnDataset( | 216 | val_dataset = VlpnDataset( |
| 217 | self.data_val, self.prompt_processor, | 217 | self.data_val, self.prompt_processor, |
| 218 | batch_size=self.batch_size, | 218 | batch_size=self.batch_size, generator=generator, |
| 219 | size=self.size, interpolation=self.interpolation, | 219 | size=self.size, interpolation=self.interpolation, |
| 220 | ) | 220 | ) |
| 221 | 221 | ||
| @@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset): | |||
| 256 | self.interpolation = interpolations[interpolation] | 256 | self.interpolation = interpolations[interpolation] |
| 257 | self.generator = generator | 257 | self.generator = generator |
| 258 | 258 | ||
| 259 | buckets, item_order, item_buckets = generate_buckets( | 259 | buckets, bucket_items, bucket_assignments = generate_buckets( |
| 260 | [item.instance_image_path for item in items], | 260 | [item.instance_image_path for item in items], |
| 261 | size, | 261 | size, |
| 262 | num_buckets, | 262 | num_buckets, |
| @@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset): | |||
| 264 | ) | 264 | ) |
| 265 | 265 | ||
| 266 | self.buckets = torch.tensor(buckets) | 266 | self.buckets = torch.tensor(buckets) |
| 267 | self.item_order = torch.tensor(item_order) | 267 | self.bucket_items = torch.tensor(bucket_items) |
| 268 | self.item_buckets = torch.tensor(item_buckets) | 268 | self.bucket_assignments = torch.tensor(bucket_assignments) |
| 269 | self.bucket_item_range = torch.arange(len(bucket_items)) | ||
| 270 | |||
| 271 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | ||
| 269 | 272 | ||
| 270 | def __len__(self): | 273 | def __len__(self): |
| 271 | return len(self.item_buckets) | 274 | return self.length_ |
| 272 | 275 | ||
| 273 | def __iter__(self): | 276 | def __iter__(self): |
| 274 | worker_info = torch.utils.data.get_worker_info() | 277 | worker_info = torch.utils.data.get_worker_info() |
| 275 | 278 | ||
| 276 | if self.shuffle: | 279 | if self.shuffle: |
| 277 | perm = torch.randperm(len(self.item_buckets), generator=self.generator) | 280 | perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) |
| 278 | self.item_order = self.item_order[perm] | 281 | self.bucket_items = self.bucket_items[perm] |
| 279 | self.item_buckets = self.item_buckets[perm] | 282 | self.bucket_assignments = self.bucket_assignments[perm] |
| 280 | 283 | ||
| 281 | item_mask = torch.ones_like(self.item_buckets, dtype=bool) | ||
| 282 | bucket = -1 | ||
| 283 | image_transforms = None | 284 | image_transforms = None |
| 285 | |||
| 286 | mask = torch.ones_like(self.bucket_assignments, dtype=bool) | ||
| 287 | bucket = -1 | ||
| 284 | batch = [] | 288 | batch = [] |
| 285 | batch_size = self.batch_size | 289 | batch_size = self.batch_size |
| 286 | 290 | ||
| @@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset): | |||
| 289 | worker_batch = math.ceil(len(self) / worker_info.num_workers) | 293 | worker_batch = math.ceil(len(self) / worker_info.num_workers) |
| 290 | start = worker_info.id * worker_batch | 294 | start = worker_info.id * worker_batch |
| 291 | end = start + worker_batch | 295 | end = start + worker_batch |
| 292 | item_mask[:start] = False | 296 | mask[:start] = False |
| 293 | item_mask[end:] = False | 297 | mask[end:] = False |
| 294 | 298 | ||
| 295 | while item_mask.any(): | 299 | while mask.any(): |
| 296 | item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] | 300 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) |
| 301 | bucket_items = self.bucket_items[bucket_mask] | ||
| 297 | 302 | ||
| 298 | if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): | 303 | if len(batch) >= batch_size: |
| 299 | yield batch | 304 | yield batch |
| 300 | batch = [] | 305 | batch = [] |
| 301 | 306 | ||
| 302 | if len(item_indices) == 0: | 307 | if len(bucket_items) == 0: |
| 303 | bucket = self.item_buckets[item_mask][0] | 308 | if len(batch) != 0: |
| 309 | yield batch | ||
| 310 | batch = [] | ||
| 311 | |||
| 312 | bucket = self.bucket_assignments[mask][0] | ||
| 304 | ratio = self.buckets[bucket] | 313 | ratio = self.buckets[bucket] |
| 305 | width = self.size * ratio if ratio > 1 else self.size | 314 | width = self.size * ratio if ratio > 1 else self.size |
| 306 | height = self.size / ratio if ratio < 1 else self.size | 315 | height = self.size / ratio if ratio < 1 else self.size |
| 307 | 316 | ||
| 308 | image_transforms = transforms.Compose( | 317 | image_transforms = transforms.Compose( |
| 309 | [ | 318 | [ |
| 310 | transforms.Resize(min(width, height), interpolation=self.interpolation), | 319 | transforms.Resize(self.size, interpolation=self.interpolation), |
| 311 | transforms.RandomCrop((height, width)), | 320 | transforms.RandomCrop((height, width)), |
| 312 | transforms.RandomHorizontalFlip(), | 321 | transforms.RandomHorizontalFlip(), |
| 313 | transforms.ToTensor(), | 322 | transforms.ToTensor(), |
| @@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset): | |||
| 315 | ] | 324 | ] |
| 316 | ) | 325 | ) |
| 317 | else: | 326 | else: |
| 318 | item_index = item_indices[0] | 327 | item_index = bucket_items[0] |
| 319 | item = self.items[item_index] | 328 | item = self.items[item_index] |
| 320 | item_mask[item_index] = False | 329 | mask[self.bucket_item_range[bucket_mask][0]] = False |
| 321 | 330 | ||
| 322 | example = {} | 331 | example = {} |
| 323 | 332 | ||
| 324 | example["prompts"] = keywords_to_prompt(item.prompt) | 333 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) |
| 325 | example["cprompts"] = item.cprompt | 334 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) |
| 326 | example["nprompts"] = item.nprompt | ||
| 327 | 335 | ||
| 328 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 336 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
| 329 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 337 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
| @@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset): | |||
| 332 | 340 | ||
| 333 | if self.num_class_images != 0: | 341 | if self.num_class_images != 0: |
| 334 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 342 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
| 335 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) | 343 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) |
| 336 | 344 | ||
| 337 | batch.append(example) | 345 | batch.append(example) |
| 338 | 346 | ||
