import math
import torch
import json

from pathlib import Path
from typing import NamedTuple, Optional, Union, Callable

from PIL import Image

from torch.utils.data import IterableDataset, DataLoader, random_split
from torchvision import transforms

from data.keywords import prompt_to_keywords, keywords_to_prompt
from models.clip.prompt import PromptProcessor


image_cache: dict[str, Image.Image] = {}

interpolations = {
    "linear": transforms.InterpolationMode.NEAREST,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,
}


def get_image(path):
    """Load an image as RGB, caching it for repeated access."""
    if path in image_cache:
        return image_cache[path]

    image = Image.open(path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_cache[path] = image

    return image


def prepare_prompt(prompt: Union[str, dict[str, str]]):
    return {"content": prompt} if isinstance(prompt, str) else prompt


def generate_buckets(
    items: list[str],
    base_size: int,
    step_size: int = 64,
    max_pixels: Optional[int] = None,
    num_buckets: int = 4,
    progressive_buckets: bool = False,
    return_tensor: bool = True,
):
    """Build aspect-ratio buckets around base_size and assign each image to
    the bucket(s) matching its aspect ratio."""
    if max_pixels is None:
        max_pixels = (base_size + step_size) ** 2

    max_pixels = max(max_pixels, base_size * base_size)

    bucket_items: list[int] = []
    bucket_assignments: list[int] = []
    buckets = [1.0]

    # Grow the long side step by step and shrink the short side so the total
    # pixel count stays within max_pixels; add both orientations of each ratio.
    for i in range(1, num_buckets + 1):
        long_side = base_size + i * step_size
        short_side = min(
            base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
            base_size,
        )
        buckets.append(long_side / short_side)
        buckets.append(short_side / long_side)

    buckets = torch.tensor(buckets)
    bucket_indices = torch.arange(len(buckets))

    for i, item in enumerate(items):
        image = get_image(item)
        ratio = image.width / image.height

        # Candidate buckets lie between the square bucket and the image's ratio.
        if ratio >= 1:
            mask = torch.logical_and(buckets >= 1, buckets <= ratio)
        else:
            mask = torch.logical_and(buckets <= 1, buckets >= ratio)

        if not progressive_buckets:
            # Keep only the single closest bucket; masked-out buckets are
            # pushed to infinity so argmin ignores them.
            inf = torch.zeros_like(buckets)
            inf[~mask] = math.inf
            mask = (buckets + inf - ratio).abs().argmin()

        indices = bucket_indices[mask]

        if len(indices.shape) == 0:
            indices = indices.unsqueeze(0)

        bucket_items += [i] * len(indices)
        bucket_assignments += indices.tolist()

    if return_tensor:
        bucket_items = torch.tensor(bucket_items)
        bucket_assignments = torch.tensor(bucket_assignments)
    else:
        buckets = buckets.tolist()

    return buckets, bucket_items, bucket_assignments


class VlpnDataItem(NamedTuple):
    instance_image_path: Path
    class_image_path: Path
    prompt: list[str]
    cprompt: str
    nprompt: str
    collection: list[str]


class VlpnDataModule:
    def __init__(
        self,
        batch_size: int,
        data_file: str,
        prompt_processor: PromptProcessor,
        class_subdir: str = "cls",
        num_class_images: int = 1,
        size: int = 768,
        num_buckets: int = 0,
        bucket_step_size: int = 64,
        bucket_max_pixels: Optional[int] = None,
        progressive_buckets: bool = False,
        dropout: float = 0,
        interpolation: str = "bicubic",
        template_key: str = "template",
        valid_set_size: Optional[int] = None,
        seed: Optional[int] = None,
        filter: Optional[Callable[[VlpnDataItem], bool]] = None,
        collate_fn=None,
        num_workers: int = 0
    ):
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)
        self.num_class_images = num_class_images

        self.prompt_processor = prompt_processor
        self.size = size
        self.num_buckets = num_buckets
        self.bucket_step_size = bucket_step_size
        self.bucket_max_pixels = bucket_max_pixels
        self.progressive_buckets = progressive_buckets
        self.dropout = dropout
        self.template_key = template_key
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.seed = seed
        self.filter = filter
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.batch_size = batch_size

    def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
        """Expand raw metadata entries into VlpnDataItems using the template
        strings and keyword expansions."""
        image = template.get("image", "{}")
        prompt = template.get("prompt", "{content}")
        cprompt = template.get("cprompt", "{content}")
        nprompt = template.get("nprompt", "{content}")

        return [
            VlpnDataItem(
                self.data_root.joinpath(image.format(item["image"])),
                None,
                prompt_to_keywords(
                    prompt.format(**prepare_prompt(item.get("prompt", ""))),
                    expansions
                ),
                keywords_to_prompt(prompt_to_keywords(
                    cprompt.format(**prepare_prompt(item.get("prompt", ""))),
                    expansions
                )),
                keywords_to_prompt(prompt_to_keywords(
                    nprompt.format(**prepare_prompt(item.get("nprompt", ""))),
                    expansions
                )),
                item["collection"].split(", ") if "collection" in item else []
            )
            for item in data
        ]

    def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
        if self.filter is None:
            return items

        return [item for item in items if self.filter(item)]

    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
        """Duplicate each item once per class image and fill in the
        corresponding class image path."""
        image_multiplier = max(num_class_images, 1)

        return [
            VlpnDataItem(
                item.instance_image_path,
                self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                item.prompt,
                item.cprompt,
                item.nprompt,
                item.collection,
            )
            for item in items
            for i in range(image_multiplier)
        ]

    def setup(self):
        with open(self.data_file, "rt") as f:
            metadata = json.load(f)

        template = metadata.get(self.template_key, {})
        expansions = metadata.get("expansions", {})
        items = metadata.get("items", [])

        items = self.prepare_items(template, expansions, items)
        items = self.filter_items(items)

        num_images = len(items)

        valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_images - valid_set_size

        generator = torch.Generator(device="cpu")
        if self.seed is not None:
            generator = generator.manual_seed(self.seed)

        data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)

        self.data_train = self.pad_items(data_train, self.num_class_images)
        self.data_val = self.pad_items(data_val)

        train_dataset = VlpnDataset(
            self.data_train, self.prompt_processor,
            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
            bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
            batch_size=self.batch_size, generator=generator,
            size=self.size, interpolation=self.interpolation,
            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
        )

        val_dataset = VlpnDataset(
            self.data_val, self.prompt_processor,
            batch_size=self.batch_size, generator=generator,
            size=self.size, interpolation=self.interpolation,
        )

        # The datasets yield ready-made batches, so the loaders use batch_size=None.
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
        )

        self.val_dataloader = DataLoader(
            val_dataset,
            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
        )
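
# Illustrative metadata layout read by VlpnDataModule.setup(), inferred from
# prepare_items(); all file names and values here are hypothetical, and the
# exact "expansions" format is defined by data.keywords.prompt_to_keywords:
#
# {
#     "template": {
#         "image": "images/{}",
#         "prompt": "a photo of {content}",
#         "cprompt": "a photo of {content}",
#         "nprompt": "{content}"
#     },
#     "expansions": {},
#     "items": [
#         {
#             "image": "0001.jpg",
#             "prompt": "a dog in the park",
#             "nprompt": "blurry, low quality",
#             "collection": "pets, dogs"
#         }
#     ]
# }
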
class VlpnDataset(IterableDataset):
    """Iterable dataset that groups images into aspect-ratio buckets and
    yields batches whose images share the same target resolution."""

    def __init__(
        self,
        items: list[VlpnDataItem],
        prompt_processor: PromptProcessor,
        num_buckets: int = 1,
        bucket_step_size: int = 64,
        bucket_max_pixels: Optional[int] = None,
        progressive_buckets: bool = False,
        batch_size: int = 1,
        num_class_images: int = 0,
        size: int = 768,
        dropout: float = 0,
        shuffle: bool = False,
        interpolation: str = "bicubic",
        generator: Optional[torch.Generator] = None,
    ):
        self.items = items
        self.batch_size = batch_size

        self.prompt_processor = prompt_processor
        self.num_class_images = num_class_images
        self.size = size
        self.dropout = dropout
        self.shuffle = shuffle
        self.interpolation = interpolations[interpolation]
        self.generator = generator

        self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
            [item.instance_image_path for item in items],
            base_size=size,
            step_size=bucket_step_size,
            num_buckets=num_buckets,
            max_pixels=bucket_max_pixels,
            progressive_buckets=progressive_buckets,
        )

        self.bucket_item_range = torch.arange(len(self.bucket_items))

        self.cache = {}

        # Total number of batches: each bucket is batched separately.
        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()

    def __len__(self):
        return self.length_

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if self.shuffle:
            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
            self.bucket_items = self.bucket_items[perm]
            self.bucket_assignments = self.bucket_assignments[perm]

        image_transforms = None

        mask = torch.ones_like(self.bucket_assignments, dtype=torch.bool)
        bucket = -1
        batch = []
        batch_size = self.batch_size

        if worker_info is not None:
            # Give each worker a proportional share of the batch size and a
            # disjoint, contiguous slice of the items (the slice must be
            # computed over items, not batches, since mask indexes items).
            batch_size = math.ceil(batch_size / worker_info.num_workers)
            items_per_worker = math.ceil(len(self.bucket_items) / worker_info.num_workers)
            start = worker_info.id * items_per_worker
            end = start + items_per_worker
            mask[:start] = False
            mask[end:] = False

        while mask.any():
            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
            bucket_items = self.bucket_items[bucket_mask]

            if len(batch) >= batch_size:
                yield batch
                batch = []

            if len(bucket_items) == 0:
                # Current bucket is exhausted: flush the partial batch and
                # switch to the next bucket that still has unvisited items.
                if len(batch) != 0:
                    yield batch
                    batch = []

                bucket = self.bucket_assignments[mask][0]
                ratio = self.buckets[bucket]
                width = int(self.size * ratio) if ratio > 1 else self.size
                height = int(self.size / ratio) if ratio < 1 else self.size

                image_transforms = transforms.Compose(
                    [
                        transforms.Resize(self.size, interpolation=self.interpolation),
                        transforms.RandomCrop((height, width)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.5], [0.5]),
                    ]
                )
            else:
                item_index = bucket_items[0]
                item = self.items[item_index]
                mask[self.bucket_item_range[bucket_mask][0]] = False

                example = {}

                example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
                example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)

                example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
                    keywords_to_prompt(item.prompt, self.dropout, True)
                )
                example["instance_images"] = image_transforms(get_image(item.instance_image_path))

                if self.num_class_images != 0:
                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
                    example["class_images"] = image_transforms(get_image(item.class_image_path))

                batch.append(example)

        if len(batch) != 0:
            yield batch
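
# Minimal usage sketch (hedged): `prompt_processor` must be a PromptProcessor
# from models.clip.prompt, and "data/train.json" is a hypothetical metadata
# file in the layout sketched above. Kept as a comment because constructing a
# PromptProcessor depends on code outside this module.
#
# datamodule = VlpnDataModule(
#     batch_size=4,
#     data_file="data/train.json",
#     prompt_processor=prompt_processor,
#     num_buckets=4,
#     size=768,
#     seed=42,
# )
# datamodule.setup()
#
# for batch in datamodule.train_dataloader:
#     # Each batch is a list of example dicts with "prompt_ids",
#     # "instance_images", etc.; a custom collate_fn can stack them.
#     ...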