import math
import torch
import json
from pathlib import Path
import pytorch_lightning as pl
from PIL import Image

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

from typing import Dict, NamedTuple, List, Optional, Union

from models.clip.prompt import PromptProcessor


def prepare_prompt(prompt: Union[str, Dict[str, str]]):
    """Normalize a prompt into a mapping usable as ``str.format(**mapping)``.

    A plain string becomes ``{"content": prompt}``; a mapping is returned
    unchanged so its keys can fill named placeholders in a template.
    """
    return {"content": prompt} if isinstance(prompt, str) else prompt


class CSVDataItem(NamedTuple):
    """One resolved sample produced by ``CSVDataModule.prepare_subdata``."""

    # Instance image path, joined onto the directory containing the data file.
    instance_image_path: Path
    # Path inside the class directory where the matching class image lives.
    class_image_path: Path
    # Instance prompt after template substitution; any remaining "{}" is
    # filled with the instance identifier by CSVDataset.get_input_ids.
    prompt: str
    # Second prompt after template substitution; its "{}" is filled with the
    # class identifier by CSVDataset.get_input_ids.
    nprompt: str


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module driven by a JSON metadata file.

    Despite the class name, ``data_file`` is parsed with ``json.load``. The
    top-level object may contain a ``"template"`` mapping (optional keys
    ``"image"``, ``"prompt"``, ``"nprompt"``) and an ``"items"`` list. Items
    whose ``"skip"`` value equals ``True`` are ignored.
    """

    def __init__(
        self,
        batch_size: int,
        data_file: str,
        prompt_processor: PromptProcessor,
        instance_identifier: str,
        class_identifier: Optional[str] = None,
        class_subdir: str = "cls",
        num_class_images: int = 100,
        size: int = 512,
        repeats: int = 1,
        interpolation: str = "bicubic",
        center_crop: bool = False,
        valid_set_size: Optional[int] = None,
        generator: Optional[torch.Generator] = None,
        collate_fn=None,
        num_workers: int = 0
    ):
        """Store configuration and create the class-image directory.

        Raises:
            ValueError: if ``data_file`` does not point to an existing file.
        """
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        # Created eagerly so class image paths computed later have a parent dir.
        self.class_root.mkdir(parents=True, exist_ok=True)
        self.num_class_images = num_class_images

        self.prompt_processor = prompt_processor
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.batch_size = batch_size

        # Populated by prepare_data() (or lazily by setup()).
        self.data_train: Optional[List[CSVDataItem]] = None
        self.data_val: Optional[List[CSVDataItem]] = None

    def prepare_subdata(self, template, data, num_class_images=1) -> List[CSVDataItem]:
        """Expand raw metadata items into ``CSVDataItem`` records.

        Each item is repeated ``ceil(num_class_images / len(data))`` times
        (at least once), so that there are at least ``num_class_images``
        records when that value is positive. Each repeat gets its own class
        image path ``<stem>_<i><suffix>``.

        Args:
            template: mapping with optional ``"image"``, ``"prompt"`` and
                ``"nprompt"`` format strings (defaults ``"{}"``,
                ``"{content}"``, ``"{content}"``).
            data: sized iterable of item mappings; each needs an ``"image"``
                key and may have ``"prompt"`` / ``"nprompt"``.
            num_class_images: minimum number of class images to cover.

        Returns:
            A list of ``CSVDataItem``; empty when ``data`` is empty.
        """
        image = template.get("image", "{}")
        prompt = template.get("prompt", "{content}")
        nprompt = template.get("nprompt", "{content}")

        # Guard: the multiplier below divides by len(data).
        if len(data) == 0:
            return []

        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)

        # NOTE(review): class image names use only the file stem, so two items
        # sharing a stem in different folders map to the same class image.
        return [
            CSVDataItem(
                self.data_root.joinpath(image.format(item["image"])),
                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
                prompt.format(**prepare_prompt(item.get("prompt", ""))),
                nprompt.format(**prepare_prompt(item.get("nprompt", "")))
            )
            for item in data
            for i in range(image_multiplier)
        ]

    def _load_splits(self):
        """Read the metadata file and populate ``data_train`` / ``data_val``.

        Raises:
            ValueError: if fewer than two non-skipped items are available
                (one is always reserved for validation).
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.data_file, 'rt', encoding="utf-8") as f:
            metadata = json.load(f)

        template = metadata.get("template", {})
        items = metadata.get("items", [])

        # Only values equal to True skip an item (same as the original check).
        items = [item for item in items if item.get("skip") != True]  # noqa: E712

        num_images = len(items)
        if num_images < 2:
            raise ValueError(
                f"{self.data_file} must contain at least 2 non-skipped items "
                f"(found {num_images}); one is reserved for validation"
            )

        # 10% of the items, optionally capped by valid_set_size, minimum 1.
        valid_set_size = int(num_images * 0.1)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_images - valid_set_size

        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)

        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
        self.data_val = self.prepare_subdata(template, data_val)

    def prepare_data(self):
        """Load the metadata and build the train/validation splits."""
        self._load_splits()

    def setup(self, stage=None):
        """Build the training and validation datasets and dataloaders."""
        # Per the Lightning docs, prepare_data() runs on a single process only,
        # so state assigned there can be missing on other processes.
        if self.data_train is None or self.data_val is None:
            self._load_splits()

        train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
                                   instance_identifier=self.instance_identifier,
                                   class_identifier=self.class_identifier,
                                   num_class_images=self.num_class_images,
                                   size=self.size, interpolation=self.interpolation,
                                   center_crop=self.center_crop, repeats=self.repeats)
        # Validation omits class images (CSVDataset's num_class_images defaults to 0).
        val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
                                 instance_identifier=self.instance_identifier,
                                 size=self.size, interpolation=self.interpolation,
                                 center_crop=self.center_crop)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
                                            num_workers=self.num_workers)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
                                          pin_memory=True, collate_fn=self.collate_fn,
                                          num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the training dataloader built in ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation dataloader built in ``setup``."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset yielding transformed instance (and optionally class) images.

    The reported length is the number of items times ``repeats``, rounded up
    to a multiple of ``batch_size``. Indices wrap around the item list.
    """

    def __init__(
        self,
        data: List[CSVDataItem],
        prompt_processor: PromptProcessor,
        instance_identifier: str,
        batch_size: int = 1,
        class_identifier: Optional[str] = None,
        num_class_images: int = 0,
        size: int = 512,
        repeats: int = 1,
        interpolation: str = "bicubic",
        center_crop: bool = False,
    ):
        """Configure the dataset.

        Raises:
            KeyError: if ``interpolation`` is not one of ``"nearest"``,
                ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"lanczos"``.
        """
        self.data = data
        self.prompt_processor = prompt_processor
        self.batch_size = batch_size
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.num_class_images = num_class_images
        # Unbounded cache of decoded images keyed by path. Presumably each
        # DataLoader worker process holds its own copy.
        self.image_cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        # "linear" is an alias of bilinear (as in PIL); it previously mapped
        # to NEAREST by mistake.
        self.interpolation = {
            "nearest": transforms.InterpolationMode.NEAREST,
            "linear": transforms.InterpolationMode.BILINEAR,
            "bilinear": transforms.InterpolationMode.BILINEAR,
            "bicubic": transforms.InterpolationMode.BICUBIC,
            "lanczos": transforms.InterpolationMode.LANCZOS,
        }[interpolation]

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # Maps ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        # Pad up to a whole number of batches; extra indices wrap around.
        return math.ceil(self._length / self.batch_size) * self.batch_size

    def get_image(self, path):
        """Open ``path`` as an RGB PIL image, memoizing the result."""
        if path in self.image_cache:
            return self.image_cache[path]

        image = Image.open(path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        self.image_cache[path] = image

        return image

    def get_input_ids(self, prompt, identifier):
        """Substitute ``identifier`` into ``prompt``'s "{}" and tokenize it."""
        # NOTE(review): an identifier of None is formatted as the text "None".
        return self.prompt_processor.get_input_ids(prompt.format(identifier))

    def get_example(self, i):
        """Return the untransformed example (PIL images) for index ``i``."""
        item = self.data[i % self.num_instance_images]

        example = {}

        example["prompts"] = item.prompt
        example["nprompts"] = item.nprompt

        example["instance_images"] = self.get_image(item.instance_image_path)
        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)

        if self.num_class_images != 0:
            example["class_images"] = self.get_image(item.class_image_path)
            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)

        return example

    def __getitem__(self, i):
        """Return example ``i`` with image transforms applied."""
        example = {}
        unprocessed_example = self.get_example(i)

        example["prompts"] = unprocessed_example["prompts"]
        example["nprompts"] = unprocessed_example["nprompts"]

        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]

        if self.num_class_images != 0:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]

        return example