import math
import json
from pathlib import Path
from typing import Dict, NamedTuple, List, Optional, Union, Callable

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

from models.clip.prompt import PromptProcessor
from data.keywords import prompt_to_keywords, keywords_to_prompt


# Process-wide cache so each image file is opened and decoded at most once.
image_cache: dict[Path, Image.Image] = {}


def get_image(path: Path) -> Image.Image:
    if path in image_cache:
        return image_cache[path]

    image = Image.open(path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_cache[path] = image
    return image


def prepare_prompt(prompt: Union[str, Dict[str, str]]) -> Dict[str, str]:
    # Wrap bare strings so templates can always interpolate "{content}".
    return {"content": prompt} if isinstance(prompt, str) else prompt


class VlpnDataItem(NamedTuple):
    instance_image_path: Path
    class_image_path: Optional[Path]  # None until pad_items() fills it in
    prompt: list[str]
    cprompt: str
    nprompt: str
    collection: list[str]


class VlpnDataBucket:
    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        self.ratio = width / height
        self.items: list[VlpnDataItem] = []


class VlpnDataModule:
    def __init__(
        self,
        batch_size: int,
        data_file: str,
        prompt_processor: PromptProcessor,
        class_subdir: str = "cls",
        num_class_images: int = 1,
        size: int = 768,
        num_aspect_ratio_buckets: int = 0,
        progressive_aspect_ratio_buckets: bool = False,
        repeats: int = 1,
        dropout: float = 0,
        interpolation: str = "bicubic",
        template_key: str = "template",
        valid_set_size: Optional[int] = None,
        seed: Optional[int] = None,
        filter: Optional[Callable[[VlpnDataItem], bool]] = None,
        collate_fn=None,
        num_workers: int = 0
    ):
        super().__init__()

        self.data_file = Path(data_file)
        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)

        self.num_class_images = num_class_images
        self.prompt_processor = prompt_processor
        self.size = size
        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
        self.repeats = repeats
        self.dropout = dropout
        self.template_key = template_key
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.seed = seed
        self.filter = filter
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.batch_size = batch_size

    def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
        image = template.get("image", "{}")
        prompt = template.get("prompt", "{content}")
        cprompt = template.get("cprompt", "{content}")
        nprompt = template.get("nprompt", "{content}")

        # Class image paths are left as None here; pad_items() fills them in.
        return [
            VlpnDataItem(
                self.data_root.joinpath(image.format(item["image"])),
                None,
                prompt_to_keywords(
                    prompt.format(**prepare_prompt(item.get("prompt", ""))),
                    expansions
                ),
                keywords_to_prompt(prompt_to_keywords(
                    cprompt.format(**prepare_prompt(item.get("prompt", ""))),
                    expansions
                )),
                keywords_to_prompt(prompt_to_keywords(
                    nprompt.format(**prepare_prompt(item.get("nprompt", ""))),
                    expansions
                )),
                item["collection"].split(", ") if "collection" in item else []
            )
            for item in data
        ]

    def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
        if self.filter is None:
            return items
        return [item for item in items if self.filter(item)]
    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
        # Duplicate every item once per class image and attach the class image path.
        image_multiplier = max(num_class_images, 1)

        return [
            VlpnDataItem(
                item.instance_image_path,
                self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                item.prompt,
                item.cprompt,
                item.nprompt,
                item.collection,
            )
            for item in items
            for i in range(image_multiplier)
        ]

    def generate_buckets(self, items: list[VlpnDataItem]):
        # Always provide a square bucket; every further step adds one landscape
        # and one portrait bucket whose long edge grows by 64 px.
        buckets = [VlpnDataBucket(self.size, self.size)]

        for i in range(1, self.num_aspect_ratio_buckets + 1):
            s = self.size + i * 64
            buckets.append(VlpnDataBucket(s, self.size))
            buckets.append(VlpnDataBucket(self.size, s))

        buckets = np.array(buckets)
        bucket_ratios = np.array([bucket.ratio for bucket in buckets])

        for item in items:
            image = get_image(item.instance_image_path)
            ratio = image.width / image.height

            # Candidate buckets lie on the same side of square as the image and
            # are no more extreme than the image's own aspect ratio.
            if ratio >= 1:
                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
            else:
                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)

            if not self.progressive_aspect_ratio_buckets:
                # Assign the item to the single closest candidate only.
                ratios = bucket_ratios.copy()
                ratios[~mask] = math.inf
                mask = [np.argmin(np.abs(ratios - ratio))]

            for bucket in buckets[mask]:
                bucket.items.append(item)

        return [bucket for bucket in buckets if len(bucket.items) != 0]

    def setup(self):
        with open(self.data_file, "rt") as f:
            metadata = json.load(f)

        template = metadata.get(self.template_key, {})
        expansions = metadata.get("expansions", {})
        items = metadata.get("items", [])

        items = self.prepare_items(template, expansions, items)
        items = self.filter_items(items)

        num_images = len(items)

        valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_images - valid_set_size

        generator = torch.Generator(device="cpu")
        if self.seed is not None:
            generator = generator.manual_seed(self.seed)

        data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)

        self.data_train = self.pad_items(data_train, self.num_class_images)
        self.data_val = self.pad_items(data_val)

        # Bucket the padded items: bucketing the raw split would leave
        # class_image_path as None and crash when class images are enabled.
        buckets = self.generate_buckets(self.data_train)

        train_datasets = [
            VlpnDataset(
                bucket.items,
                self.prompt_processor,
                width=bucket.width,
                height=bucket.height,
                interpolation=self.interpolation,
                num_class_images=self.num_class_images,
                repeats=self.repeats,
                dropout=self.dropout,
            )
            for bucket in buckets
        ]

        val_dataset = VlpnDataset(
            self.data_val,
            self.prompt_processor,
            width=self.size,
            height=self.size,
            interpolation=self.interpolation,
        )

        self.train_dataloaders = [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                collate_fn=self.collate_fn,
                num_workers=self.num_workers
            )
            for dataset in train_datasets
        ]

        self.val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            pin_memory=True,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers
        )


class VlpnDataset(Dataset):
    def __init__(
        self,
        data: List[VlpnDataItem],
        prompt_processor: PromptProcessor,
        num_class_images: int = 0,
        width: int = 768,
        height: int = 768,
        repeats: int = 1,
        dropout: float = 0,
        interpolation: str = "bicubic",
    ):
        self.data = data
        self.prompt_processor = prompt_processor
        self.num_class_images = num_class_images
        self.dropout = dropout

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        self.interpolation = {
            "linear": transforms.InterpolationMode.NEAREST,
            "bilinear": transforms.InterpolationMode.BILINEAR,
            "bicubic": transforms.InterpolationMode.BICUBIC,
            "lanczos": transforms.InterpolationMode.LANCZOS,
        }[interpolation]
        # Resize the shorter edge to the bucket's shorter edge, then crop out
        # the exact bucket resolution.
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(min(width, height), interpolation=self.interpolation),
                transforms.RandomCrop((height, width)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def get_example(self, i):
        # With repeats > 1 the dataset is longer than the item list; wrap around.
        item = self.data[i % self.num_instance_images]

        example = {}
        example["prompts"] = item.prompt
        example["cprompts"] = item.cprompt
        example["nprompts"] = item.nprompt
        example["instance_images"] = get_image(item.instance_image_path)
        if self.num_class_images != 0:
            example["class_images"] = get_image(item.class_image_path)

        return example

    def __getitem__(self, i):
        unprocessed_example = self.get_example(i)

        example = {}
        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"])
        example["cprompts"] = unprocessed_example["cprompts"]
        example["nprompts"] = unprocessed_example["nprompts"]
        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        # Keyword dropout is applied only when building the tokenized prompt.
        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
            keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
        )

        if self.num_class_images != 0:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])

        return example
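

# Minimal usage sketch (illustrative only). The metadata path and the
# PromptProcessor construction below are assumptions, not part of this module;
# a real training script would also pass a collate_fn that pads the
# variable-length prompt ids before batching.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    prompt_processor = PromptProcessor(tokenizer)  # hypothetical constructor signature

    datamodule = VlpnDataModule(
        batch_size=4,
        data_file="data/train/metadata.json",  # placeholder path
        prompt_processor=prompt_processor,
        size=768,
        num_aspect_ratio_buckets=4,
        seed=42,
    )
    datamodule.setup()

    # One training dataloader per non-empty aspect-ratio bucket,
    # plus a single square validation dataloader.
    print(f"train loaders: {len(datamodule.train_dataloaders)}")
    print(f"val batches:   {len(datamodule.val_dataloader)}")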