From 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 13:38:43 +0100 Subject: Fixed aspect ratio bucketing; allow passing token IDs to pipeline --- data/csv.py | 78 ++++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 43 insertions(+), 35 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 9be36ba..289a64d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): - item_order: list[int] = [] - item_buckets: list[int] = [] + bucket_items: list[int] = [] + bucket_assignments: list[int] = [] buckets = [1.0] for i in range(1, num_buckets + 1): @@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ if len(indices.shape) == 0: indices = indices.unsqueeze(0) - item_order += [i] * len(indices) - item_buckets += indices + bucket_items += [i] * len(indices) + bucket_assignments += indices - return buckets.tolist(), item_order, item_buckets + return buckets.tolist(), bucket_items, bucket_assignments class VlpnDataItem(NamedTuple): @@ -94,8 +94,8 @@ class VlpnDataModule(): class_subdir: str = "cls", num_class_images: int = 1, size: int = 768, - num_aspect_ratio_buckets: int = 0, - progressive_aspect_ratio_buckets: bool = False, + num_buckets: int = 0, + progressive_buckets: bool = False, dropout: float = 0, interpolation: str = "bicubic", template_key: str = "template", @@ -119,8 +119,8 @@ class VlpnDataModule(): self.prompt_processor = prompt_processor self.size = size - self.num_aspect_ratio_buckets = num_aspect_ratio_buckets - self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets + self.num_buckets = num_buckets + self.progressive_buckets = progressive_buckets self.dropout = dropout self.template_key = template_key self.interpolation = interpolation @@ -207,15 +207,15 @@ class VlpnDataModule(): train_dataset = VlpnDataset( self.data_train, self.prompt_processor, - num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, - batch_size=self.batch_size, + num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, + batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, ) val_dataset = VlpnDataset( self.data_val, self.prompt_processor, - batch_size=self.batch_size, + batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, ) @@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset): self.interpolation = interpolations[interpolation] self.generator = generator - buckets, item_order, item_buckets = generate_buckets( + buckets, bucket_items, bucket_assignments = generate_buckets( [item.instance_image_path for item in items], size, num_buckets, @@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset): ) self.buckets = torch.tensor(buckets) - self.item_order = torch.tensor(item_order) - self.item_buckets = torch.tensor(item_buckets) + self.bucket_items = torch.tensor(bucket_items) + self.bucket_assignments = torch.tensor(bucket_assignments) + self.bucket_item_range = torch.arange(len(bucket_items)) + + self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() def __len__(self): - return len(self.item_buckets) + return self.length_ def __iter__(self): worker_info = torch.utils.data.get_worker_info() if self.shuffle: - perm = torch.randperm(len(self.item_buckets), generator=self.generator) - self.item_order = self.item_order[perm] - self.item_buckets = self.item_buckets[perm] + perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) + self.bucket_items = self.bucket_items[perm] + self.bucket_assignments = self.bucket_assignments[perm] - item_mask = torch.ones_like(self.item_buckets, dtype=bool) - bucket = -1 image_transforms = None + + mask = torch.ones_like(self.bucket_assignments, dtype=bool) + bucket = -1 batch = [] batch_size = self.batch_size @@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset): worker_batch = math.ceil(len(self) / worker_info.num_workers) start = worker_info.id * worker_batch end = start + worker_batch - item_mask[:start] = False - item_mask[end:] = False + mask[:start] = False + mask[end:] = False - while item_mask.any(): - item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] + while mask.any(): + bucket_mask = mask.logical_and(self.bucket_assignments == bucket) + bucket_items = self.bucket_items[bucket_mask] - if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): + if len(batch) >= batch_size: yield batch batch = [] - if len(item_indices) == 0: - bucket = self.item_buckets[item_mask][0] + if len(bucket_items) == 0: + if len(batch) != 0: + yield batch + batch = [] + + bucket = self.bucket_assignments[mask][0] ratio = self.buckets[bucket] width = self.size * ratio if ratio > 1 else self.size height = self.size / ratio if ratio < 1 else self.size image_transforms = transforms.Compose( [ - transforms.Resize(min(width, height), interpolation=self.interpolation), + transforms.Resize(self.size, interpolation=self.interpolation), transforms.RandomCrop((height, width)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), @@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset): ] ) else: - item_index = item_indices[0] + item_index = bucket_items[0] item = self.items[item_index] - item_mask[item_index] = False + mask[self.bucket_item_range[bucket_mask][0]] = False example = {} - example["prompts"] = keywords_to_prompt(item.prompt) - example["cprompts"] = item.cprompt - example["nprompts"] = item.nprompt + example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) + example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( @@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset): if self.num_class_images != 0: example["class_images"] = image_transforms(get_image(item.class_image_path)) - example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) + example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) batch.append(example) -- cgit v1.2.3-54-g00ecf