From 6970adaff742ac89adb3d85c803689210dc030e2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 15:05:39 +0100 Subject: Made aspect ratio bucketing configurable --- data/csv.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 59d6d8d..654aec1 100644 --- a/data/csv.py +++ b/data/csv.py @@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms from typing import Dict, NamedTuple, List, Optional, Union, Callable +import numpy as np + from models.clip.prompt import PromptProcessor from data.keywords import prompt_to_keywords, keywords_to_prompt @@ -56,6 +58,8 @@ class VlpnDataModule(): class_subdir: str = "cls", num_class_images: int = 1, size: int = 768, + num_aspect_ratio_buckets: int = 0, + progressive_aspect_ratio_buckets: bool = False, repeats: int = 1, dropout: float = 0, interpolation: str = "bicubic", @@ -80,6 +84,8 @@ class VlpnDataModule(): self.prompt_processor = prompt_processor self.size = size + self.num_aspect_ratio_buckets = num_aspect_ratio_buckets + self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets self.repeats = repeats self.dropout = dropout self.template_key = template_key @@ -143,25 +149,32 @@ class VlpnDataModule(): def generate_buckets(self, items: list[VlpnDataItem]): buckets = [VlpnDataBucket(self.size, self.size)] - for i in range(1, 5): + for i in range(1, self.num_aspect_ratio_buckets + 1): s = self.size + i * 64 buckets.append(VlpnDataBucket(s, self.size)) buckets.append(VlpnDataBucket(self.size, s)) + buckets = np.array(buckets) + bucket_ratios = np.array([bucket.ratio for bucket in buckets]) + for item in items: image = get_image(item.instance_image_path) ratio = image.width / image.height if ratio >= 1: - candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] + mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio) else: - candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] + mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio) - for bucket in candidates: + if not self.progressive_aspect_ratio_buckets: + ratios = bucket_ratios.copy() + ratios[~mask] = math.inf + mask = [np.argmin(np.abs(ratios - ratio))] + + for bucket in buckets[mask]: bucket.items.append(item) - buckets = [bucket for bucket in buckets if len(bucket.items) != 0] - return buckets + return [bucket for bucket in buckets if len(bucket.items) != 0] def setup(self): with open(self.data_file, 'rt') as f: @@ -192,7 +205,7 @@ class VlpnDataModule(): train_datasets = [ VlpnDataset( - bucket.items, self.prompt_processor, batch_size=self.batch_size, + bucket.items, self.prompt_processor, width=bucket.width, height=bucket.height, interpolation=self.interpolation, num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, ) @@ -200,7 +213,7 @@ class VlpnDataModule(): ] val_dataset = VlpnDataset( - data_val, self.prompt_processor, batch_size=self.batch_size, + data_val, self.prompt_processor, width=self.size, height=self.size, interpolation=self.interpolation, ) @@ -223,7 +236,6 @@ class VlpnDataset(Dataset): self, data: List[VlpnDataItem], prompt_processor: PromptProcessor, - batch_size: int = 1, num_class_images: int = 0, width: int = 768, height: int = 768, @@ -234,7 +246,6 @@ class VlpnDataset(Dataset): self.data = data self.prompt_processor = prompt_processor - self.batch_size = batch_size self.num_class_images = num_class_images self.dropout = dropout @@ -258,7 +269,7 @@ class VlpnDataset(Dataset): ) def __len__(self): - return math.ceil(self._length / self.batch_size) * self.batch_size + return self._length def get_example(self, i): item = self.data[i % self.num_instance_images] -- cgit v1.2.3-70-g09d2