From 6970adaff742ac89adb3d85c803689210dc030e2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 7 Jan 2023 15:05:39 +0100
Subject: Made aspect ratio bucketing configurable

---
 data/csv.py      | 33 ++++++++++++++++++++++-----------
 train_ti.py      | 13 +++++++++++++
 training/util.py |  9 ++-------
 3 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 59d6d8d..654aec1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
 from typing import Dict, NamedTuple, List, Optional, Union, Callable

+import numpy as np
+
 from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt

@@ -56,6 +58,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
+        num_aspect_ratio_buckets: int = 0,
+        progressive_aspect_ratio_buckets: bool = False,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
@@ -80,6 +84,8 @@ class VlpnDataModule():

         self.prompt_processor = prompt_processor
         self.size = size
+        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
+        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
         self.repeats = repeats
         self.dropout = dropout
         self.template_key = template_key
@@ -143,25 +149,32 @@ class VlpnDataModule():
     def generate_buckets(self, items: list[VlpnDataItem]):
         buckets = [VlpnDataBucket(self.size, self.size)]

-        for i in range(1, 5):
+        for i in range(1, self.num_aspect_ratio_buckets + 1):
             s = self.size + i * 64
             buckets.append(VlpnDataBucket(s, self.size))
             buckets.append(VlpnDataBucket(self.size, s))

+        buckets = np.array(buckets)
+        bucket_ratios = np.array([bucket.ratio for bucket in buckets])
+
         for item in items:
             image = get_image(item.instance_image_path)
             ratio = image.width / image.height

             if ratio >= 1:
-                candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
+                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
             else:
-                candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
+                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)

-            for bucket in candidates:
+            if not self.progressive_aspect_ratio_buckets:
+                ratios = bucket_ratios.copy()
+                ratios[~mask] = math.inf
+                mask = [np.argmin(np.abs(ratios - ratio))]
+
+            for bucket in buckets[mask]:
                 bucket.items.append(item)

-        buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
-        return buckets
+        return [bucket for bucket in buckets if len(bucket.items) != 0]

     def setup(self):
         with open(self.data_file, 'rt') as f:
@@ -192,7 +205,7 @@ class VlpnDataModule():

         train_datasets = [
             VlpnDataset(
-                bucket.items, self.prompt_processor, batch_size=self.batch_size,
+                bucket.items, self.prompt_processor,
                 width=bucket.width, height=bucket.height, interpolation=self.interpolation,
                 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
             )
@@ -200,7 +213,7 @@ class VlpnDataModule():
         ]

         val_dataset = VlpnDataset(
-            data_val, self.prompt_processor, batch_size=self.batch_size,
+            data_val, self.prompt_processor,
             width=self.size, height=self.size, interpolation=self.interpolation,
         )

@@ -223,7 +236,6 @@ class VlpnDataset(Dataset):
         self,
         data: List[VlpnDataItem],
         prompt_processor: PromptProcessor,
-        batch_size: int = 1,
         num_class_images: int = 0,
         width: int = 768,
         height: int = 768,
@@ -234,7 +246,6 @@ class VlpnDataset(Dataset):

         self.data = data
         self.prompt_processor = prompt_processor
-        self.batch_size = batch_size
         self.num_class_images = num_class_images
         self.dropout = dropout

@@ -258,7 +269,7 @@ class VlpnDataset(Dataset):
         )

     def __len__(self):
-        return math.ceil(self._length / self.batch_size) * self.batch_size
+        return self._length

     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
diff --git a/train_ti.py b/train_ti.py
index 89c6672..38c9755 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -145,6 +145,17 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--num_aspect_ratio_buckets",
+        type=int,
+        default=4,
+        help="Number of buckets in either direction (adds 64 pixels per step).",
+    )
+    parser.add_argument(
+        "--progressive_aspect_ratio_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
     parser.add_argument(
         "--tag_dropout",
         type=float,
@@ -710,6 +721,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
+        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
+        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
         repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
diff --git a/training/util.py b/training/util.py
index 6f42228..2b7f71d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import json
 import copy
+import itertools
 from typing import Iterable, Optional
 from contextlib import contextmanager

@@ -71,13 +72,7 @@ class CheckpointerBase:
         file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)

-        data_enum = enumerate(data)
-
-        batches = [
-            batch
-            for j, batch in data_enum
-            if j * data.batch_size < self.sample_batch_size * self.sample_batches
-        ]
+        batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
         prompts = [
             prompt
             for batch in batches
-- 
cgit v1.2.3-70-g09d2
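
Note: the heart of this commit is the bucket-selection step in VlpnDataModule.generate_buckets. The standalone sketch below replays that logic outside the repository so the two modes are easy to compare; the select_buckets helper and the example sizes and ratios are illustrative assumptions, not code from this patch.

import math

import numpy as np


def select_buckets(ratio, bucket_ratios, progressive):
    # Consider only buckets on the same side of 1:1 as the image, and never
    # a bucket more extreme than the image's own aspect ratio.
    if ratio >= 1:
        mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
    else:
        mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)

    if not progressive:
        # Default mode: keep only the single compatible bucket whose ratio
        # is closest to the image's ratio.
        ratios = bucket_ratios.copy()
        ratios[~mask] = math.inf
        return np.array([np.argmin(np.abs(ratios - ratio))])

    # Progressive mode: keep every compatible bucket, down to 1:1.
    return np.flatnonzero(mask)


# Illustrative: 768px base size with num_aspect_ratio_buckets=2 yields
# buckets 896x768, 832x768, 768x768, 768x832, 768x896 (64px per step).
bucket_ratios = np.array([896 / 768, 832 / 768, 1.0, 768 / 832, 768 / 896])
print(select_buckets(1.5, bucket_ratios, progressive=False))  # [0] -> 896x768 only
print(select_buckets(1.5, bucket_ratios, progressive=True))   # [0 1 2] -> down to 1:1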