import json
import math
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Union

import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

from models.clip.prompt import PromptProcessor


def prepare_prompt(prompt: Union[str, Dict[str, str]]) -> Dict[str, str]:
    """Return *prompt* as a mapping suitable for ``str.format(**mapping)``.

    A plain string is wrapped as ``{"content": prompt}``; a mapping is
    returned unchanged.
    """
    return {"content": prompt} if isinstance(prompt, str) else prompt


def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str:
    """Join keywords into a comma-separated prompt in random order.

    Args:
        prompt: Keywords to join. The list is not modified.
        dropout: When non-zero, each keyword is kept only if a fresh
            ``np.random.random()`` draw is greater than this value.

    Returns:
        The surviving keywords, shuffled and joined with ``", "``.
    """
    if dropout != 0:
        keywords = [keyword for keyword in prompt if np.random.random() > dropout]
    else:
        # Copy before shuffling: np.random.shuffle works in place, and the
        # caller's list is the one stored on the dataset item.
        keywords = list(prompt)
    np.random.shuffle(keywords)
    return ", ".join(keywords)


def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
    """Split a ``", "``-separated prompt into keywords and apply expansions.

    Empty keywords are skipped. A keyword that is a key of *expansions* is
    followed in the result by the ``", "``-separated parts of its expansion.
    """
    keywords: list[str] = []
    for keyword in prompt.split(", "):
        if keyword == "":
            continue
        keywords.append(keyword)
        if keyword in expansions:
            keywords.extend(expansions[keyword].split(", "))
    return keywords


class CSVDataItem(NamedTuple):
    """One dataset entry: image paths plus positive and negative prompt."""

    # Path of the instance image, resolved against the data file's directory.
    instance_image_path: Path
    # Path of the class image; None until ``CSVDataModule.pad_items`` fills it.
    class_image_path: Optional[Path]
    # Positive prompt as a list of keywords.
    prompt: list[str]
    # Negative prompt as a single formatted string.
    nprompt: str


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module built from a JSON metadata file.

    The metadata file is read with ``json.load`` and may contain a template
    (under ``template_key``), an ``"expansions"`` mapping and an ``"items"``
    list. Items are split into a training and a validation set and wrapped in
    ``CSVDataset`` instances.
    """

    def __init__(
        self,
        batch_size: int,
        data_file: str,
        prompt_processor: PromptProcessor,
        instance_identifier: str,
        class_identifier: Optional[str] = None,
        class_subdir: str = "cls",
        num_class_images: int = 100,
        size: int = 512,
        repeats: int = 1,
        dropout: float = 0,
        interpolation: str = "bicubic",
        center_crop: bool = False,
        mode: Optional[str] = None,
        template_key: str = "template",
        valid_set_size: Optional[int] = None,
        generator: Optional[torch.Generator] = None,
        filter: Optional[Callable[[CSVDataItem], bool]] = None,
        collate_fn=None,
        num_workers: int = 0
    ):
        """Store the configuration and create the class image directory.

        Args:
            batch_size: Batch size of both data loaders.
            data_file: Path of the JSON metadata file.
            prompt_processor: Forwarded to the datasets for tokenisation.
            instance_identifier: Value formatted into prompts for instance images.
            class_identifier: Value formatted into prompts for class images.
            class_subdir: Directory (relative to the data file's directory)
                holding class images; created if missing.
            num_class_images: Target number of class images for the training set.
            size: Image size passed to the resize and crop transforms.
            repeats: Multiplier applied to the training set length.
            dropout: Keyword dropout used for training prompts.
            interpolation: Name of the resize interpolation mode.
            center_crop: Use a center crop instead of a random crop.
            mode: When set, keep only items whose ``"mode"`` entry contains it.
            template_key: Metadata key under which the template is stored.
            valid_set_size: Upper bound for the validation set size.
            generator: Generator passed to ``random_split``.
            filter: Optional predicate; items for which it is falsy are dropped.
            collate_fn: Forwarded to both data loaders.
            num_workers: Forwarded to both data loaders.

        Raises:
            ValueError: If *data_file* is not an existing file.
        """
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)
        self.num_class_images = num_class_images

        self.prompt_processor = prompt_processor
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.dropout = dropout
        self.center_crop = center_crop
        self.mode = mode
        self.template_key = template_key
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.filter = filter
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.batch_size = batch_size

    def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
        """Convert raw metadata entries into ``CSVDataItem`` tuples.

        Template strings default to ``"{}"`` for the image and ``"{content}"``
        for both prompts. The class image path is left as ``None``.
        """
        image = template.get("image", "{}")
        prompt = template.get("prompt", "{content}")
        nprompt = template.get("nprompt", "{content}")

        return [
            CSVDataItem(
                self.data_root.joinpath(image.format(item["image"])),
                None,
                prompt_to_keywords(
                    prompt.format(**prepare_prompt(item.get("prompt", ""))),
                    expansions,
                ),
                nprompt.format(**prepare_prompt(item.get("nprompt", ""))),
            )
            for item in data
        ]

    def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
        """Return the items accepted by ``self.filter`` (all items if unset)."""
        if self.filter is None:
            return items

        return [item for item in items if self.filter(item)]

    def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
        """Repeat items so that at least *num_class_images* entries exist.

        Every item is emitted ``max(ceil(num_class_images / len(items)), 1)``
        times; copy ``i`` gets the class image path
        ``<class_root>/<instance stem>_<i><instance suffix>``.

        Returns:
            The padded list, or an empty list when *items* is empty.
        """
        num_items = len(items)
        if num_items == 0:
            # Nothing to pad; also avoids dividing by zero below.
            return []

        image_multiplier = max(math.ceil(num_class_images / num_items), 1)

        return [
            CSVDataItem(
                item.instance_image_path,
                self.class_root.joinpath(
                    f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"
                ),
                item.prompt,
                item.nprompt,
            )
            for item in items
            for i in range(image_multiplier)
        ]

    def prepare_data(self):
        """Load the metadata file and build ``data_train`` / ``data_val``.

        The validation set takes 10% of the items (rounded down), capped by
        ``self.valid_set_size`` when that is set, and is never smaller than
        one item. The remaining items form the training set.

        Raises:
            ValueError: If fewer than two items remain after filtering, so
                that no training item would be left.

        NOTE(review): Lightning documents ``prepare_data`` as running on a
        single process; confirm that assigning state here is intended when
        training with several processes.
        """
        with open(self.data_file, "rt", encoding="utf-8") as f:
            metadata = json.load(f)

        template = metadata.get(self.template_key, {})
        expansions = metadata.get("expansions", {})
        items = metadata.get("items", [])

        if self.mode is not None:
            items = [
                item
                for item in items
                if "mode" in item and self.mode in item["mode"]
            ]

        items = self.prepare_items(template, expansions, items)
        items = self.filter_items(items)

        num_images = len(items)

        valid_set_size = int(num_images * 0.1)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_images - valid_set_size

        if train_set_size < 1:
            raise ValueError(
                f"{self.data_file} yields {num_images} usable item(s); "
                "at least 2 are required to build training and validation sets"
            )

        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)

        self.data_train = self.pad_items(data_train, self.num_class_images)
        self.data_val = self.pad_items(data_val)

    def setup(self, stage=None):
        """Create the datasets and data loaders from the prepared splits."""
        train_dataset = CSVDataset(
            self.data_train,
            self.prompt_processor,
            batch_size=self.batch_size,
            instance_identifier=self.instance_identifier,
            class_identifier=self.class_identifier,
            num_class_images=self.num_class_images,
            size=self.size,
            interpolation=self.interpolation,
            center_crop=self.center_crop,
            repeats=self.repeats,
            dropout=self.dropout,
        )
        # The validation set uses the dataset defaults for class images,
        # repeats and dropout.
        val_dataset = CSVDataset(
            self.data_val,
            self.prompt_processor,
            batch_size=self.batch_size,
            instance_identifier=self.instance_identifier,
            size=self.size,
            interpolation=self.interpolation,
            center_crop=self.center_crop,
        )
        self.train_dataloader_ = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
        )
        self.val_dataloader_ = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            pin_memory=True,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the training data loader created by ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation data loader created by ``setup``."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset yielding transformed images together with tokenised prompts."""

    def __init__(
        self,
        data: List[CSVDataItem],
        prompt_processor: PromptProcessor,
        instance_identifier: str,
        batch_size: int = 1,
        class_identifier: Optional[str] = None,
        num_class_images: int = 0,
        size: int = 512,
        repeats: int = 1,
        dropout: float = 0,
        interpolation: str = "bicubic",
        center_crop: bool = False,
    ):
        """Store the items and build the image transform pipeline.

        Args:
            data: Items to serve.
            prompt_processor: Object whose ``get_input_ids`` tokenises prompts.
            instance_identifier: Value formatted into instance prompts.
            batch_size: The reported length is rounded up to a multiple of it.
            class_identifier: Value formatted into class prompts.
            num_class_images: When non-zero, class images are returned as well.
            size: Target size for resizing and cropping.
            repeats: Multiplier applied to the number of items.
            dropout: Keyword dropout passed to ``keywords_to_prompt``.
            interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
                ``"lanczos"``.
            center_crop: Use a center crop instead of a random crop.

        Raises:
            KeyError: If *interpolation* is not one of the supported names.
        """
        self.data = data
        self.prompt_processor = prompt_processor
        self.batch_size = batch_size
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.num_class_images = num_class_images
        self.dropout = dropout
        self.image_cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        self.interpolation = {
            # NOTE(review): "linear" selects NEAREST, not a linear filter.
            # Kept as-is to avoid changing results; confirm this is intended.
            "linear": transforms.InterpolationMode.NEAREST,
            "bilinear": transforms.InterpolationMode.BILINEAR,
            "bicubic": transforms.InterpolationMode.BICUBIC,
            "lanczos": transforms.InterpolationMode.LANCZOS,
        }[interpolation]
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the item count times repeats, rounded up to a full batch."""
        return math.ceil(self._length / self.batch_size) * self.batch_size

    def get_image(self, path):
        """Return the RGB image stored at *path*, using the in-memory cache."""
        cached = self.image_cache.get(path)
        if cached is not None:
            return cached

        # Image.open is lazy; convert() forces the pixel data to be read and
        # returns an independent image, so the file can be closed right away
        # instead of staying attached to the cached object.
        with Image.open(path) as source:
            image = source.convert("RGB")

        self.image_cache[path] = image
        return image

    def get_input_ids(self, prompt, identifier):
        """Format *identifier* into *prompt* and tokenise the result."""
        return self.prompt_processor.get_input_ids(prompt.format(identifier))

    def get_example(self, i):
        """Return prompts and untransformed images for index *i*.

        Indices wrap around the number of stored items.
        """
        item = self.data[i % self.num_instance_images]

        example = {}
        example["prompts"] = item.prompt
        example["nprompts"] = item.nprompt
        example["instance_images"] = self.get_image(item.instance_image_path)

        if self.num_class_images != 0:
            example["class_images"] = self.get_image(item.class_image_path)

        return example

    def __getitem__(self, i):
        """Return the training example for index *i*.

        The result holds the joined prompt, the negative prompt, the
        transformed instance image and its prompt ids; class image and class
        prompt ids are added when ``num_class_images`` is non-zero.
        """
        unprocessed_example = self.get_example(i)

        example = {}
        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
        example["nprompts"] = unprocessed_example["nprompts"]
        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)

        if self.num_class_images != 0:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)

        return example