import math
from pathlib import Path
from typing import List, NamedTuple, Optional

import pandas as pd
import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

from models.clip.prompt import PromptProcessor


class CSVDataItem(NamedTuple):
    """One training example.

    Fields:
        instance_image_path: path of the user-supplied instance image.
        class_image_path: path of the (to-be-generated) class/regularization image.
        prompt: positive prompt template (contains a ``{}`` placeholder for an identifier).
        nprompt: negative prompt template.
    """
    instance_image_path: Path
    class_image_path: Path
    prompt: str
    nprompt: str


class CSVDataModule(pl.LightningDataModule):
    """LightningDataModule that reads a JSON metadata file (one record per image,
    with ``image``, ``prompt``, ``nprompt`` and optional ``skip`` columns) and
    produces train/validation dataloaders of :class:`CSVDataItem`-backed examples.
    """

    def __init__(
        self,
        batch_size: int,
        data_file: str,
        prompt_processor: PromptProcessor,
        instance_identifier: str,
        class_identifier: Optional[str] = None,
        class_subdir: str = "cls",
        num_class_images: int = 100,
        size: int = 512,
        repeats: int = 1,
        interpolation: str = "bicubic",
        center_crop: bool = False,
        valid_set_size: Optional[int] = None,
        generator: Optional[torch.Generator] = None,
        collate_fn=None,
    ):
        """
        Args:
            batch_size: dataloader batch size.
            data_file: path to the JSON metadata file; its parent directory is
                taken as the image root.
            prompt_processor: tokenizes prompts into input ids.
            instance_identifier: substituted into each prompt's ``{}`` placeholder
                for instance examples.
            class_identifier: substituted for class/regularization examples.
            class_subdir: subdirectory (under the data root) where class images live.
            num_class_images: minimum number of class-image slots to create.
            size: output image side length.
            repeats: how many times the training set is repeated per epoch.
            interpolation: one of ``linear``/``bilinear``/``bicubic``/``lanczos``.
            center_crop: center-crop instead of random-crop.
            valid_set_size: optional cap on the validation split size.
            generator: RNG used for the train/valid split.
            collate_fn: passed through to both DataLoaders.

        Raises:
            ValueError: if ``data_file`` is not an existing file.
        """
        super().__init__()

        self.data_file = Path(data_file)
        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)

        self.num_class_images = num_class_images
        self.prompt_processor = prompt_processor
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.collate_fn = collate_fn
        self.batch_size = batch_size

    def prepare_subdata(self, data, num_class_images=1):
        """Expand metadata rows into :class:`CSVDataItem` entries.

        Each row is replicated ``ceil(num_class_images / len(data))`` times (at
        least once) so that the dataset exposes at least ``num_class_images``
        distinct class-image paths; replica ``i`` gets ``<stem>_<i><suffix>``.
        Returns an empty list for empty input.
        """
        if not data:
            # Guard: the original ceil(num / 0) would raise ZeroDivisionError.
            return []

        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)

        return [
            CSVDataItem(
                self.data_root.joinpath(item.image),
                self.class_root.joinpath(
                    f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"
                ),
                item.prompt,
                item.nprompt,
            )
            for item in data
            for i in range(image_multiplier)
        ]

    def prepare_data(self):
        """Read the metadata file, drop skipped rows, and split train/valid."""
        metadata = pd.read_json(self.data_file)
        # Keep a row unless it has a `skip` column whose value equals True.
        # `!= True` (not `is not True`) is deliberate: it also excludes skip=1,
        # matching the original semantics for truthy non-bool values.
        metadata = [
            item for item in metadata.itertuples()
            if getattr(item, "skip", False) != True  # noqa: E712
        ]

        num_images = len(metadata)

        # Default to a 20% validation split, optionally capped, never empty.
        valid_set_size = int(num_images * 0.2)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_images - valid_set_size

        data_train, data_val = random_split(
            metadata, [train_set_size, valid_set_size], generator=self.generator
        )

        # NOTE(review): Lightning runs prepare_data() on a single process only,
        # so state assigned here is not replicated under DDP — consider moving
        # the split into setup(). Kept as-is to preserve existing behavior.
        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
        self.data_val = self.prepare_subdata(data_val)

    def setup(self, stage=None):
        """Build datasets and dataloaders from the splits made in prepare_data()."""
        train_dataset = CSVDataset(
            self.data_train,
            self.prompt_processor,
            batch_size=self.batch_size,
            instance_identifier=self.instance_identifier,
            class_identifier=self.class_identifier,
            num_class_images=self.num_class_images,
            size=self.size,
            interpolation=self.interpolation,
            center_crop=self.center_crop,
            repeats=self.repeats,
        )
        # Validation uses instance images only (no class images, no repeats).
        val_dataset = CSVDataset(
            self.data_val,
            self.prompt_processor,
            batch_size=self.batch_size,
            instance_identifier=self.instance_identifier,
            size=self.size,
            interpolation=self.interpolation,
            center_crop=self.center_crop,
        )
        self.train_dataloader_ = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )
        self.val_dataloader_ = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self.train_dataloader_

    def val_dataloader(self):
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset yielding dicts with prompts, tokenized prompt ids, and transformed
    instance (and optionally class) images, padded to a whole number of batches.
    """

    def __init__(
        self,
        data: List[CSVDataItem],
        prompt_processor: PromptProcessor,
        instance_identifier: str,
        batch_size: int = 1,
        class_identifier: Optional[str] = None,
        num_class_images: int = 0,
        size: int = 512,
        repeats: int = 1,
        interpolation: str = "bicubic",
        center_crop: bool = False,
    ):
        self.data = data
        self.prompt_processor = prompt_processor
        self.batch_size = batch_size
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.num_class_images = num_class_images
        self.cache = {}        # per-item example cache (pre-transform)
        self.image_cache = {}  # path -> PIL.Image; grows unbounded over an epoch

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        # Map the interpolation name onto torchvision's enum. "linear" maps to
        # BILINEAR: this table was ported from PIL, where Image.LINEAR is an
        # alias of Image.BILINEAR — the previous mapping to NEAREST was a bug.
        self.interpolation = {
            "linear": transforms.InterpolationMode.BILINEAR,
            "bilinear": transforms.InterpolationMode.BILINEAR,
            "bicubic": transforms.InterpolationMode.BICUBIC,
            "lanczos": transforms.InterpolationMode.LANCZOS,
        }[interpolation]

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # Scale [0, 1] tensors to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        # Round up to a full multiple of batch_size; __getitem__ wraps indices
        # with modulo, so the padding re-serves earlier examples.
        return math.ceil(self._length / self.batch_size) * self.batch_size

    def get_image(self, path):
        """Open (and memoize) the image at ``path``, converted to RGB."""
        if path in self.image_cache:
            return self.image_cache[path]

        image = Image.open(path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        self.image_cache[path] = image
        return image

    def get_example(self, i):
        """Build (and memoize) the untransformed example for index ``i``."""
        item = self.data[i % self.num_instance_images]
        cache_key = f"{item.instance_image_path}_{item.class_image_path}"

        if cache_key in self.cache:
            return self.cache[cache_key]

        example = {}
        example["prompts"] = item.prompt
        example["nprompts"] = item.nprompt
        example["instance_images"] = self.get_image(item.instance_image_path)
        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
            item.prompt.format(self.instance_identifier)
        )

        if self.num_class_images != 0:
            example["class_images"] = self.get_image(item.class_image_path)
            # NOTE(review): class conditioning tokenizes nprompt (the negative
            # prompt template) with the class identifier — verify this is
            # intentional and not meant to be item.prompt. Kept as-is.
            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
                item.nprompt.format(self.class_identifier)
            )

        self.cache[cache_key] = example
        return example

    def __getitem__(self, i):
        """Return the example for index ``i`` with image transforms applied.

        Transforms run here (not in get_example) so random crop/flip are
        re-sampled on every access instead of being frozen in the cache.
        """
        example = {}
        unprocessed_example = self.get_example(i)

        example["prompts"] = unprocessed_example["prompts"]
        example["nprompts"] = unprocessed_example["nprompts"]
        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]

        if self.num_class_images != 0:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]

        return example