From 32599d2e00f051413851a18dbd01fecdd8edfc62 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 15 Feb 2023 14:53:59 +0100 Subject: Improved batch padding --- data/csv.py | 55 ++++++++++++++++++++++++++----------------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/data/csv.py b/data/csv.py index 480e9f2..233f5d8 100644 --- a/data/csv.py +++ b/data/csv.py @@ -399,23 +399,20 @@ class VlpnDataset(IterableDataset): mask[:start] = False mask[end:] = False - while mask.any(): + while mask.any() or len(batch) != 0: + if len(batch) >= batch_size: + yield batch + batch = [] + bucket_mask = mask.logical_and(self.bucket_assignments == bucket) bucket_items = self.bucket_items[bucket_mask] - if len(batch) >= batch_size: + if len(bucket_items) == 0 and len(batch) != 0 and not self.fill_batch: yield batch batch = [] + continue - if len(bucket_items) == 0: - if len(batch) != 0: - if self.fill_batch: - fill_items = self.bucket_items[self.bucket_assignments == bucket] - fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) - batch += fill_items[fill_perm] - yield batch - batch = [] - + if len(bucket_items) == 0 and len(batch) == 0: bucket = self.bucket_assignments[mask][0] ratio = self.buckets[bucket] width = int(self.size * ratio) if ratio > 1 else self.size @@ -430,30 +427,30 @@ class VlpnDataset(IterableDataset): transforms.Normalize([0.5], [0.5]), ] ) + + continue + + if len(bucket_items) == 0: + bucket_items = self.bucket_items[self.bucket_assignments == bucket] + item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] else: item_index = bucket_items[0] - item = self.items[item_index] mask[self.bucket_item_range[bucket_mask][0]] = False - example = {} + item = self.items[item_index] - example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.get_input_ids(item.nprompt) + example = {} - example["instance_prompt_ids"] = self.get_input_ids( - keywords_to_prompt(item.prompt, self.dropout, True) - ) - example["instance_images"] = image_transforms(get_image(item.instance_image_path)) + example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) + example["nprompt_ids"] = self.get_input_ids(item.nprompt) - if self.num_class_images != 0: - example["class_prompt_ids"] = self.get_input_ids(item.cprompt) - example["class_images"] = image_transforms(get_image(item.class_image_path)) + example["instance_prompt_ids"] = self.get_input_ids( + keywords_to_prompt(item.prompt, self.dropout, True) + ) + example["instance_images"] = image_transforms(get_image(item.instance_image_path)) - batch.append(example) + if self.num_class_images != 0: + example["class_prompt_ids"] = self.get_input_ids(item.cprompt) + example["class_images"] = image_transforms(get_image(item.class_image_path)) - if len(batch) != 0: - if self.fill_batch: - fill_items = self.bucket_items[self.bucket_assignments == bucket] - fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) - batch += fill_items[fill_perm] - yield batch + batch.append(example) -- cgit v1.2.3-54-g00ecf