From 3ee13893f9a4973ac75f45fe9318c35760dd4b1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 13:57:46 +0100 Subject: Added progressive aspect ratio bucketing --- data/csv.py | 144 +++++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 88 insertions(+), 56 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 4986153..59d6d8d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor from data.keywords import prompt_to_keywords, keywords_to_prompt +image_cache: dict[str, Image.Image] = {} + + +def get_image(path): + if path in image_cache: + return image_cache[path] + + image = Image.open(path) + if not image.mode == "RGB": + image = image.convert("RGB") + image_cache[path] = image + + return image + + def prepare_prompt(prompt: Union[str, Dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt -class CSVDataItem(NamedTuple): +class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] @@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple): collection: list[str] -class CSVDataModule(): +class VlpnDataBucket(): + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.ratio = width / height + self.items: list[VlpnDataItem] = [] + + +class VlpnDataModule(): def __init__( self, batch_size: int, @@ -36,11 +59,10 @@ class CSVDataModule(): repeats: int = 1, dropout: float = 0, interpolation: str = "bicubic", - center_crop: bool = False, template_key: str = "template", valid_set_size: Optional[int] = None, seed: Optional[int] = None, - filter: Optional[Callable[[CSVDataItem], bool]] = None, + filter: Optional[Callable[[VlpnDataItem], bool]] = None, collate_fn=None, num_workers: int = 0 ): @@ -60,7 +82,6 @@ class CSVDataModule(): self.size = size self.repeats = repeats self.dropout = dropout - self.center_crop = center_crop self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -70,14 +91,14 @@ class CSVDataModule(): self.num_workers = num_workers self.batch_size = batch_size - def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: + def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: image = template["image"] if "image" in template else "{}" prompt = template["prompt"] if "prompt" in template else "{content}" cprompt = template["cprompt"] if "cprompt" in template else "{content}" nprompt = template["nprompt"] if "nprompt" in template else "{content}" return [ - CSVDataItem( + VlpnDataItem( self.data_root.joinpath(image.format(item["image"])), None, prompt_to_keywords( @@ -97,17 +118,17 @@ class CSVDataModule(): for item in data ] - def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: + def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: if self.filter is None: return items return [item for item in items if self.filter(item)] - def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: + def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: image_multiplier = max(num_class_images, 1) return [ - CSVDataItem( + VlpnDataItem( item.instance_image_path, self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), item.prompt, @@ -119,7 +140,30 @@ class CSVDataModule(): for i in range(image_multiplier) ] - def prepare_data(self): + def 
generate_buckets(self, items: list[VlpnDataItem]): + buckets = [VlpnDataBucket(self.size, self.size)] + + for i in range(1, 5): + s = self.size + i * 64 + buckets.append(VlpnDataBucket(s, self.size)) + buckets.append(VlpnDataBucket(self.size, s)) + + for item in items: + image = get_image(item.instance_image_path) + ratio = image.width / image.height + + if ratio >= 1: + candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] + else: + candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] + + for bucket in candidates: + bucket.items.append(item) + + buckets = [bucket for bucket in buckets if len(bucket.items) != 0] + return buckets + + def setup(self): with open(self.data_file, 'rt') as f: metadata = json.load(f) template = metadata[self.template_key] if self.template_key in metadata else {} @@ -144,48 +188,48 @@ class CSVDataModule(): self.data_train = self.pad_items(data_train, self.num_class_images) self.data_val = self.pad_items(data_val) - def setup(self, stage=None): - train_dataset = CSVDataset( - self.data_train, self.prompt_processor, batch_size=self.batch_size, - num_class_images=self.num_class_images, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout - ) - val_dataset = CSVDataset( - self.data_val, self.prompt_processor, batch_size=self.batch_size, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop - ) - self.train_dataloader_ = DataLoader( - train_dataset, batch_size=self.batch_size, - shuffle=True, pin_memory=True, collate_fn=self.collate_fn, - num_workers=self.num_workers - ) - self.val_dataloader_ = DataLoader( - val_dataset, batch_size=self.batch_size, - pin_memory=True, collate_fn=self.collate_fn, - num_workers=self.num_workers + buckets = self.generate_buckets(data_train) + + train_datasets = [ + VlpnDataset( + bucket.items, self.prompt_processor, batch_size=self.batch_size, + width=bucket.width, height=bucket.height, interpolation=self.interpolation, + num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, + ) + for bucket in buckets + ] + + val_dataset = VlpnDataset( + data_val, self.prompt_processor, batch_size=self.batch_size, + width=self.size, height=self.size, interpolation=self.interpolation, ) - def train_dataloader(self): - return self.train_dataloader_ + self.train_dataloaders = [ + DataLoader( + dataset, batch_size=self.batch_size, shuffle=True, + pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + ) + for dataset in train_datasets + ] - def val_dataloader(self): - return self.val_dataloader_ + self.val_dataloader = DataLoader( + val_dataset, batch_size=self.batch_size, + pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + ) -class CSVDataset(Dataset): +class VlpnDataset(Dataset): def __init__( self, - data: List[CSVDataItem], + data: List[VlpnDataItem], prompt_processor: PromptProcessor, batch_size: int = 1, num_class_images: int = 0, - size: int = 768, + width: int = 768, + height: int = 768, repeats: int = 1, dropout: float = 0, interpolation: str = "bicubic", - center_crop: bool = False, ): self.data = data @@ -193,7 +237,6 @@ class CSVDataset(Dataset): self.batch_size = batch_size self.num_class_images = num_class_images self.dropout = dropout - self.image_cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats @@ -206,8 +249,8 @@ class 
CSVDataset(Dataset): }[interpolation] self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=self.interpolation), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize(min(width, height), interpolation=self.interpolation), + transforms.RandomCrop((height, width)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -217,17 +260,6 @@ class CSVDataset(Dataset): def __len__(self): return math.ceil(self._length / self.batch_size) * self.batch_size - def get_image(self, path): - if path in self.image_cache: - return self.image_cache[path] - - image = Image.open(path) - if not image.mode == "RGB": - image = image.convert("RGB") - self.image_cache[path] = image - - return image - def get_example(self, i): item = self.data[i % self.num_instance_images] @@ -235,9 +267,9 @@ class CSVDataset(Dataset): example["prompts"] = item.prompt example["cprompts"] = item.cprompt example["nprompts"] = item.nprompt - example["instance_images"] = self.get_image(item.instance_image_path) + example["instance_images"] = get_image(item.instance_image_path) if self.num_class_images != 0: - example["class_images"] = self.get_image(item.class_image_path) + example["class_images"] = get_image(item.class_image_path) return example -- cgit v1.2.3-70-g09d2
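
A short note on the bucketing rule this patch introduces: the sketch below is not part of the patch itself; it restates the logic of generate_buckets() with hypothetical standalone names (Bucket, generate_buckets, assign) so it runs without the rest of the repository. Starting from the square base resolution, four progressively wider and four progressively taller buckets are created in 64 px steps, and each image is assigned to every bucket whose aspect ratio it can fully cover.

from dataclasses import dataclass, field


@dataclass
class Bucket:
    # Hypothetical stand-in for VlpnDataBucket from the patch.
    width: int
    height: int
    items: list = field(default_factory=list)

    @property
    def ratio(self) -> float:
        return self.width / self.height


def generate_buckets(base_size: int = 768, steps: int = 4, step_px: int = 64) -> list[Bucket]:
    # One square bucket plus `steps` landscape and `steps` portrait buckets.
    buckets = [Bucket(base_size, base_size)]
    for i in range(1, steps + 1):
        s = base_size + i * step_px
        buckets.append(Bucket(s, base_size))  # landscape (wider than tall)
        buckets.append(Bucket(base_size, s))  # portrait (taller than wide)
    return buckets


def assign(buckets: list[Bucket], item: str, width: int, height: int) -> None:
    # An image goes into every bucket on its side of 1:1 whose aspect ratio it
    # meets or exceeds (landscape) / meets or falls below (portrait).
    ratio = width / height
    if ratio >= 1:
        candidates = [b for b in buckets if b.ratio >= 1 and ratio >= b.ratio]
    else:
        candidates = [b for b in buckets if b.ratio <= 1 and ratio <= b.ratio]
    for b in candidates:
        b.items.append(item)


buckets = generate_buckets()
assign(buckets, "example.png", 900, 768)  # aspect ratio ~1.17
# Lands in the 768x768, 832x768 and 896x768 buckets, but not in 960x768 or
# 1024x768, whose ratios (1.25, ~1.33) exceed ~1.17.
print([(b.width, b.height) for b in buckets if b.items])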
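
Two consumer-facing consequences of the patch are worth noting. First, because an item is only assigned to buckets whose aspect ratio it can fully cover, Resize(min(width, height)) followed by RandomCrop((height, width)) always has enough pixels to crop from, so no padding is ever needed. Second, the old prepare_data()/setup(stage) pair and the train_dataloader()/val_dataloader() accessors are gone: after calling setup(), callers read the train_dataloaders list (one DataLoader per non-empty bucket, each yielding batches at a single resolution) and the val_dataloader attribute (square images at the base size) directly.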