From 5571c4ebcb39813e2bd8585de30c64bb02f9d7fa Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 09:43:22 +0100 Subject: Improved aspect ratio bucketing --- data/csv.py | 273 ++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 154 insertions(+), 119 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 654aec1..9be36ba 100644 --- a/data/csv.py +++ b/data/csv.py @@ -2,20 +2,28 @@ import math import torch import json from pathlib import Path +from typing import NamedTuple, Optional, Union, Callable + from PIL import Image -from torch.utils.data import Dataset, DataLoader, random_split -from torchvision import transforms -from typing import Dict, NamedTuple, List, Optional, Union, Callable -import numpy as np +from torch.utils.data import IterableDataset, DataLoader, random_split +from torchvision import transforms -from models.clip.prompt import PromptProcessor from data.keywords import prompt_to_keywords, keywords_to_prompt +from models.clip.prompt import PromptProcessor image_cache: dict[str, Image.Image] = {} +interpolations = { + "linear": transforms.InterpolationMode.NEAREST, + "bilinear": transforms.InterpolationMode.BILINEAR, + "bicubic": transforms.InterpolationMode.BICUBIC, + "lanczos": transforms.InterpolationMode.LANCZOS, +} + + def get_image(path): if path in image_cache: return image_cache[path] @@ -28,10 +36,46 @@ def get_image(path): return image -def prepare_prompt(prompt: Union[str, Dict[str, str]]): +def prepare_prompt(prompt: Union[str, dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt +def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): + item_order: list[int] = [] + item_buckets: list[int] = [] + buckets = [1.0] + + for i in range(1, num_buckets + 1): + s = size + i * 64 + buckets.append(s / size) + buckets.append(size / s) + + buckets = torch.tensor(buckets) + bucket_indices = torch.arange(len(buckets)) + + for i, item in enumerate(items): + image = get_image(item) + ratio = image.width / image.height + + if ratio >= 1: + mask = torch.bitwise_and(buckets >= 1, buckets <= ratio) + else: + mask = torch.bitwise_and(buckets <= 1, buckets >= ratio) + + if not progressive_buckets: + mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() + + indices = bucket_indices[mask] + + if len(indices.shape) == 0: + indices = indices.unsqueeze(0) + + item_order += [i] * len(indices) + item_buckets += indices + + return buckets.tolist(), item_order, item_buckets + + class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path @@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple): collection: list[str] -class VlpnDataBucket(): - def __init__(self, width: int, height: int): - self.width = width - self.height = height - self.ratio = width / height - self.items: list[VlpnDataItem] = [] - - class VlpnDataModule(): def __init__( self, @@ -60,7 +96,6 @@ class VlpnDataModule(): size: int = 768, num_aspect_ratio_buckets: int = 0, progressive_aspect_ratio_buckets: bool = False, - repeats: int = 1, dropout: float = 0, interpolation: str = "bicubic", template_key: str = "template", @@ -86,7 +121,6 @@ class VlpnDataModule(): self.size = size self.num_aspect_ratio_buckets = num_aspect_ratio_buckets self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets - self.repeats = repeats self.dropout = dropout self.template_key = template_key self.interpolation = interpolation @@ -146,36 +180,6 @@ class VlpnDataModule(): for i in range(image_multiplier) ] - def generate_buckets(self, items: list[VlpnDataItem]): - buckets = [VlpnDataBucket(self.size, self.size)] - - for i in range(1, self.num_aspect_ratio_buckets + 1): - s = self.size + i * 64 - buckets.append(VlpnDataBucket(s, self.size)) - buckets.append(VlpnDataBucket(self.size, s)) - - buckets = np.array(buckets) - bucket_ratios = np.array([bucket.ratio for bucket in buckets]) - - for item in items: - image = get_image(item.instance_image_path) - ratio = image.width / image.height - - if ratio >= 1: - mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio) - else: - mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio) - - if not self.progressive_aspect_ratio_buckets: - ratios = bucket_ratios.copy() - ratios[~mask] = math.inf - mask = [np.argmin(np.abs(ratios - ratio))] - - for bucket in buckets[mask]: - bucket.items.append(item) - - return [bucket for bucket in buckets if len(bucket.items) != 0] - def setup(self): with open(self.data_file, 'rt') as f: metadata = json.load(f) @@ -201,105 +205,136 @@ class VlpnDataModule(): self.data_train = self.pad_items(data_train, self.num_class_images) self.data_val = self.pad_items(data_val) - buckets = self.generate_buckets(data_train) - - train_datasets = [ - VlpnDataset( - bucket.items, self.prompt_processor, - width=bucket.width, height=bucket.height, interpolation=self.interpolation, - num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, - ) - for bucket in buckets - ] + train_dataset = VlpnDataset( + self.data_train, self.prompt_processor, + num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, + batch_size=self.batch_size, + size=self.size, interpolation=self.interpolation, + num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, + ) val_dataset = VlpnDataset( - data_val, self.prompt_processor, - width=self.size, height=self.size, interpolation=self.interpolation, + self.data_val, self.prompt_processor, + batch_size=self.batch_size, + size=self.size, interpolation=self.interpolation, ) - self.train_dataloaders = [ - DataLoader( - dataset, batch_size=self.batch_size, shuffle=True, - pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers - ) - for dataset in train_datasets - ] + self.train_dataloader = DataLoader( + train_dataset, + batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + ) self.val_dataloader = DataLoader( - val_dataset, batch_size=self.batch_size, - pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + val_dataset, + batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers ) -class VlpnDataset(Dataset): +class VlpnDataset(IterableDataset): def __init__( self, - data: List[VlpnDataItem], + items: list[VlpnDataItem], prompt_processor: PromptProcessor, + num_buckets: int = 1, + progressive_buckets: bool = False, + batch_size: int = 1, num_class_images: int = 0, - width: int = 768, - height: int = 768, - repeats: int = 1, + size: int = 768, dropout: float = 0, + shuffle: bool = False, interpolation: str = "bicubic", + generator: Optional[torch.Generator] = None, ): + self.items = items + self.batch_size = batch_size - self.data = data self.prompt_processor = prompt_processor self.num_class_images = num_class_images + self.size = size self.dropout = dropout - - self.num_instance_images = len(self.data) - self._length = self.num_instance_images * repeats - - self.interpolation = { - "linear": transforms.InterpolationMode.NEAREST, - "bilinear": transforms.InterpolationMode.BILINEAR, - "bicubic": transforms.InterpolationMode.BICUBIC, - "lanczos": transforms.InterpolationMode.LANCZOS, - }[interpolation] - self.image_transforms = transforms.Compose( - [ - transforms.Resize(min(width, height), interpolation=self.interpolation), - transforms.RandomCrop((height, width)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] + self.shuffle = shuffle + self.interpolation = interpolations[interpolation] + self.generator = generator + + buckets, item_order, item_buckets = generate_buckets( + [item.instance_image_path for item in items], + size, + num_buckets, + progressive_buckets ) - def __len__(self): - return self._length + self.buckets = torch.tensor(buckets) + self.item_order = torch.tensor(item_order) + self.item_buckets = torch.tensor(item_buckets) - def get_example(self, i): - item = self.data[i % self.num_instance_images] - - example = {} - example["prompts"] = item.prompt - example["cprompts"] = item.cprompt - example["nprompts"] = item.nprompt - example["instance_images"] = get_image(item.instance_image_path) - if self.num_class_images != 0: - example["class_images"] = get_image(item.class_image_path) - - return example + def __len__(self): + return len(self.item_buckets) + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + + if self.shuffle: + perm = torch.randperm(len(self.item_buckets), generator=self.generator) + self.item_order = self.item_order[perm] + self.item_buckets = self.item_buckets[perm] + + item_mask = torch.ones_like(self.item_buckets, dtype=bool) + bucket = -1 + image_transforms = None + batch = [] + batch_size = self.batch_size + + if worker_info is not None: + batch_size = math.ceil(batch_size / worker_info.num_workers) + worker_batch = math.ceil(len(self) / worker_info.num_workers) + start = worker_info.id * worker_batch + end = start + worker_batch + item_mask[:start] = False + item_mask[end:] = False + + while item_mask.any(): + item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] + + if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): + yield batch + batch = [] + + if len(item_indices) == 0: + bucket = self.item_buckets[item_mask][0] + ratio = self.buckets[bucket] + width = self.size * ratio if ratio > 1 else self.size + height = self.size / ratio if ratio < 1 else self.size + + image_transforms = transforms.Compose( + [ + transforms.Resize(min(width, height), interpolation=self.interpolation), + transforms.RandomCrop((height, width)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + else: + item_index = item_indices[0] + item = self.items[item_index] + item_mask[item_index] = False - def __getitem__(self, i): - unprocessed_example = self.get_example(i) + example = {} - example = {} + example["prompts"] = keywords_to_prompt(item.prompt) + example["cprompts"] = item.cprompt + example["nprompts"] = item.nprompt - example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"]) - example["cprompts"] = unprocessed_example["cprompts"] - example["nprompts"] = unprocessed_example["nprompts"] + example["instance_images"] = image_transforms(get_image(item.instance_image_path)) + example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( + keywords_to_prompt(item.prompt, self.dropout, True) + ) - example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) - example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( - keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) - ) + if self.num_class_images != 0: + example["class_images"] = image_transforms(get_image(item.class_image_path)) + example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) - if self.num_class_images != 0: - example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) - example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) + batch.append(example) - return example + if len(batch) != 0: + yield batch -- cgit v1.2.3-70-g09d2