import math
import pandas as pd
from pathlib import Path
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from typing import NamedTuple, List


# Resize interpolation modes selectable via the ``interpolation`` argument.
# "linear" is an alias of "bilinear" (mirroring PIL, where Image.LINEAR is
# Image.BILINEAR); it previously mapped to NEAREST, contradicting its name.
_INTERPOLATION_MODES = {
    "nearest": transforms.InterpolationMode.NEAREST,
    "linear": transforms.InterpolationMode.BILINEAR,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,
}


def _missing_to_empty(value):
    """Return ``value``, or ``""`` when it is None or NaN.

    pandas reads empty CSV cells as float NaN; prompt fields must stay strings.
    """
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return ""
    return value


def _load_rgb(path):
    """Read the image at ``path`` fully into memory as RGB and close the file.

    ``Image.open`` is lazy and keeps the file handle open until pixel data is
    loaded. ``convert`` always returns a loaded copy (even when the mode is
    already RGB), so the handle can be released immediately. This matters
    because loaded images are cached for the lifetime of the dataset.
    """
    with Image.open(path) as image:
        return image.convert("RGB")


class CSVDataItem(NamedTuple):
    """One training example resolved from a row of the CSV file."""

    instance_image_path: Path  # CSV `image` entry resolved against the CSV's folder
    class_image_path: Path  # derived path inside the class sub-directory
    prompt: str  # str.format template; filled with an instance/class identifier
    nprompt: str  # negative prompt, "" when the CSV provides none


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module whose items are listed in a CSV file.

    Columns read from the CSV:

    * ``image``   -- image path relative to the CSV file's directory (required)
    * ``prompt``  -- prompt template passed to ``str.format`` (required)
    * ``nprompt`` -- negative prompt (optional; missing/empty -> ``""``)
    * ``skip``    -- rows whose value is exactly ``"x"`` are ignored (optional)

    Class-image paths are derived under ``<csv dir>/<class_subdir>``. That
    directory is created here, but the class images themselves are not
    generated by this module.
    """

    def __init__(
        self,
        batch_size,
        data_file,
        tokenizer,
        instance_identifier,
        class_identifier=None,
        class_subdir="cls",
        num_class_images=100,
        size=512,
        repeats=100,
        interpolation="bicubic",
        center_crop=False,
        valid_set_size=None,
        generator=None,
        collate_fn=None
    ):
        """Store configuration and create the class-image directory.

        Raises:
            ValueError: if ``data_file`` is not an existing file.
        """
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)
        self.num_class_images = num_class_images

        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.collate_fn = collate_fn
        self.batch_size = batch_size

        # Filled by prepare_data(); setup() fills them itself if still unset.
        self.data_train = None
        self.data_val = None

    def prepare_subdata(self, data, num_class_images=1):
        """Convert CSV rows into ``CSVDataItem`` objects.

        Each row is emitted ``max(ceil(num_class_images / len(data)), 1)``
        times, so that roughly ``num_class_images`` class-image paths exist in
        total. Repeat ``i`` of a row gets the class path ``<stem>_<i><suffix>``.
        Returns an empty list for empty ``data``.
        """
        if len(data) == 0:
            return []

        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)

        items = []
        for row in data:
            image = Path(row.image)
            # Rows are itertuples() namedtuples: probe optional columns by
            # attribute. ``"nprompt" in row`` would search the row's values.
            nprompt = _missing_to_empty(getattr(row, "nprompt", ""))
            for i in range(image_multiplier):
                items.append(
                    CSVDataItem(
                        self.data_root.joinpath(row.image),
                        self.class_root.joinpath(f"{image.stem}_{i}{image.suffix}"),
                        row.prompt,
                        nprompt,
                    )
                )
        return items

    def prepare_data(self):
        """Read the CSV, drop skipped rows and split into train/validation.

        The validation split is 20% of the rows, capped by ``valid_set_size``
        when that is set, and always at least one row.

        Raises:
            ValueError: if fewer than two usable rows remain, since both
                splits need at least one row.
        """
        metadata = pd.read_csv(self.data_file)
        rows = list(metadata.itertuples())
        if "skip" in metadata.columns:
            rows = [row for row in rows if row.skip != "x"]

        num_images = len(rows)
        if num_images < 2:
            raise ValueError(
                f"{self.data_file} must contain at least 2 usable rows to build "
                f"a train/validation split, found {num_images}"
            )

        valid_set_size = int(num_images * 0.2)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_images - valid_set_size

        data_train, data_val = random_split(
            rows, [train_set_size, valid_set_size], generator=self.generator
        )

        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
        self.data_val = self.prepare_subdata(data_val)

    def setup(self, stage=None):
        """Build the train/validation datasets and their DataLoaders."""
        if self.data_train is None or self.data_val is None:
            # Lightning may call prepare_data() on a single process only, and
            # setup() can also be called directly; make sure the split exists.
            self.prepare_data()

        train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
                                   instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                   num_class_images=self.num_class_images,
                                   size=self.size, interpolation=self.interpolation,
                                   center_crop=self.center_crop, repeats=self.repeats)
        val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
                                 instance_identifier=self.instance_identifier,
                                 size=self.size, interpolation=self.interpolation,
                                 center_crop=self.center_crop, repeats=self.repeats)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
                                          pin_memory=True, collate_fn=self.collate_fn)

    def train_dataloader(self):
        """Return the training DataLoader built in ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation DataLoader built in ``setup``."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset over ``CSVDataItem`` objects producing transformed images and token ids.

    Loaded (untransformed) examples are cached in memory. Random crop/flip
    transforms are re-applied on every ``__getitem__`` call.

    Args:
        data: items to serve; indices wrap around modulo ``len(data)``.
        tokenizer: callable returning an object with ``input_ids``; it must
            also expose ``model_max_length``.
        instance_identifier: value formatted into each prompt template.
        batch_size: the reported length is rounded up to a multiple of this.
        class_identifier: value formatted into the class prompt. Class images
            are used only when this is set AND ``num_class_images != 0``.
        num_class_images: see ``class_identifier``.
        size: output image size passed to Resize and the crop transform.
        repeats: how many times ``data`` is repeated per epoch.
        interpolation: one of "nearest", "linear", "bilinear", "bicubic",
            "lanczos"; an unknown value raises ``KeyError``.
        center_crop: center-crop instead of random-crop.
    """

    def __init__(
        self,
        data: List[CSVDataItem],
        tokenizer,
        instance_identifier,
        batch_size=1,
        class_identifier=None,
        num_class_images=0,
        size=512,
        repeats=1,
        interpolation="bicubic",
        center_crop=False,
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.num_class_images = num_class_images
        # One condition decides both loading and emitting class images, so the
        # two can never disagree (which previously caused KeyErrors or
        # needless reads of class image files).
        self.use_class_images = class_identifier is not None and num_class_images != 0
        self.cache = {}
        self.image_cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        self.interpolation = _INTERPOLATION_MODES[interpolation]

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        # Pad up to a whole number of batches; indices wrap in get_example.
        return math.ceil(self._length / self.batch_size) * self.batch_size

    def _tokenize(self, text):
        """Return unpadded, length-truncated token ids for ``text``."""
        return self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def _instance_image(self, path):
        """Return the RGB instance image at ``path``, loading it at most once."""
        image = self.image_cache.get(path)
        if image is None:
            image = _load_rgb(path)
            self.image_cache[path] = image
        return image

    def get_example(self, i):
        """Return the cached, untransformed example for index ``i``."""
        item = self.data[i % self.num_instance_images]
        cache_key = f"{item.instance_image_path}_{item.class_image_path}"

        cached = self.cache.get(cache_key)
        if cached is not None:
            return cached

        example = {
            "prompts": item.prompt,
            "nprompts": item.nprompt,
            "instance_images": self._instance_image(item.instance_image_path),
            "instance_prompt_ids": self._tokenize(item.prompt.format(self.instance_identifier)),
        }

        if self.use_class_images:
            example["class_images"] = _load_rgb(item.class_image_path)
            example["class_prompt_ids"] = self._tokenize(item.prompt.format(self.class_identifier))

        # Store under the same key the lookup above uses, so later calls hit.
        self.cache[cache_key] = example

        return example

    def __getitem__(self, i):
        """Return example ``i`` with the image transforms freshly applied."""
        unprocessed_example = self.get_example(i)

        example = {
            "prompts": unprocessed_example["prompts"],
            "nprompts": unprocessed_example["nprompts"],
            "instance_images": self.image_transforms(unprocessed_example["instance_images"]),
            "instance_prompt_ids": unprocessed_example["instance_prompt_ids"],
        }

        if self.use_class_images:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]

        return example