import math
import os
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module backed by a CSV file that lists training images.

    The CSV must have an ``image`` column (paths relative to the CSV's
    directory) and a ``prompt`` column (a ``str.format`` template into which
    the instance/class identifier is substituted). Optional columns:
    ``nprompt`` (negative prompt, defaults to ``""``) and ``skip`` (rows whose
    value is ``"x"`` are ignored). Class images are looked up by file name
    inside ``<csv directory>/<class_subdir>``.
    """

    def __init__(self, batch_size, data_file, tokenizer, instance_identifier, class_identifier=None,
                 class_subdir="db_cls", size=512, repeats=100, interpolation="bicubic", center_crop=False,
                 valid_set_size=None, generator=None, collate_fn=None):
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        # Make sure the class image directory exists before anything reads from it.
        self.class_root.mkdir(parents=True, exist_ok=True)

        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.collate_fn = collate_fn
        self.batch_size = batch_size

    def _load_data(self):
        """Parse the CSV into ``(instance_path, class_path, prompt, nprompt)`` tuples.

        Rows whose ``skip`` value is ``"x"`` are dropped.
        """
        metadata = pd.read_csv(self.data_file)
        image_names = metadata['image'].values
        instance_image_paths = [self.data_root.joinpath(f) for f in image_names]
        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in image_names]
        prompts = metadata['prompt'].values
        # Empty CSV cells are read as NaN; normalise them to "" so every sample
        # carries a string negative prompt, as when the column is absent.
        nprompts = (
            metadata['nprompt'].fillna("").values
            if 'nprompt' in metadata
            else [""] * len(instance_image_paths)
        )
        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)

        return [
            (i, c, p, n)
            for i, c, p, n, s in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
            if s != "x"
        ]

    def prepare_data(self):
        """Read the CSV into ``self.data``.

        Lightning may call this on a single process only, so ``setup`` also
        loads the data when ``self.data`` has not been populated.
        """
        self.data = self._load_data()

    def setup(self, stage=None):
        """Split the samples into train/validation sets and build both data loaders.

        The validation split is 20% of the samples, capped at ``valid_set_size``
        when that is set, and never smaller than one sample.

        Raises:
            ValueError: if fewer than two samples remain after skipping rows,
                because each split needs at least one sample.
        """
        # prepare_data may not have run in this process (e.g. non-zero ranks).
        if getattr(self, "data", None) is None:
            self.data = self._load_data()

        num_samples = len(self.data)
        if num_samples < 2:
            raise ValueError(
                f"{self.data_file} must contain at least 2 non-skipped rows "
                f"(one for training, one for validation); got {num_samples}"
            )

        valid_set_size = int(num_samples * 0.2)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = num_samples - valid_set_size

        # NOTE(review): with generator=None each process draws its own split;
        # pass a seeded generator if the split must match across ranks.
        self.data_train, self.data_val = random_split(
            self.data, [train_set_size, valid_set_size], generator=self.generator
        )

        train_dataset = CSVDataset(self.data_train, self.tokenizer,
                                   instance_identifier=self.instance_identifier,
                                   class_identifier=self.class_identifier,
                                   size=self.size, interpolation=self.interpolation,
                                   center_crop=self.center_crop, repeats=self.repeats)
        # NOTE(review): class_identifier is not passed here, so validation
        # examples carry no class images/prompt ids — confirm collate_fn
        # tolerates that when prior preservation is enabled.
        val_dataset = CSVDataset(self.data_val, self.tokenizer,
                                 instance_identifier=self.instance_identifier,
                                 size=self.size, interpolation=self.interpolation,
                                 center_crop=self.center_crop, repeats=self.repeats)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
                                          pin_memory=True, collate_fn=self.collate_fn)

    def train_dataloader(self):
        """Return the training loader built in ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation loader built in ``setup``."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset of instance (and optional class) images with tokenized prompts.

    ``data`` is an indexable sequence of
    ``(instance_path, class_path, prompt, nprompt)`` tuples. The dataset length
    is ``len(data) * repeats`` and index ``i`` maps to ``data[i % len(data)]``.
    Decoded images and prompt token ids are cached per instance path. The
    image transforms, including the random horizontal flip, are re-applied on
    every access.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 instance_identifier,
                 class_identifier=None,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 center_crop=False,
                 ):

        self.data = data
        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        # "linear" is kept as an alias for bilinear (PIL's LINEAR is BILINEAR;
        # torchvision has no separate LINEAR mode). Unknown names raise KeyError.
        self.interpolation = {
            "nearest": transforms.InterpolationMode.NEAREST,
            "linear": transforms.InterpolationMode.BILINEAR,
            "bilinear": transforms.InterpolationMode.BILINEAR,
            "bicubic": transforms.InterpolationMode.BICUBIC,
            "lanczos": transforms.InterpolationMode.LANCZOS,
        }[interpolation]

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # Single-value mean/std is broadcast over channels: maps [0, 1] to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded RGB copy, closing the file.

        ``Image.open`` is lazy and keeps the file open until pixels are read;
        ``convert`` loads the data and returns an independent image, so cached
        images do not hold file handles.
        """
        with Image.open(path) as image:
            return image.convert("RGB")

    def get_example(self, i):
        """Return the untransformed example for index ``i`` (cached per instance path).

        Keys: ``prompts``, ``nprompts``, ``instance_images`` (PIL RGB image),
        ``instance_prompt_ids`` and, when ``class_identifier`` is set,
        ``class_images`` and ``class_prompt_ids``.
        """
        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]

        if instance_image_path in self.cache:
            return self.cache[instance_image_path]

        example = {}

        example["prompts"] = prompt
        example["nprompts"] = nprompt

        example["instance_images"] = self._load_rgb(instance_image_path)
        example["instance_prompt_ids"] = self.tokenizer(
            prompt.format(self.instance_identifier),
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_identifier is not None:
            example["class_images"] = self._load_rgb(class_image_path)
            example["class_prompt_ids"] = self.tokenizer(
                prompt.format(self.class_identifier),
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        self.cache[instance_image_path] = example
        return example

    def __getitem__(self, i):
        """Return example ``i`` with freshly transformed image tensors."""
        example = {}
        unprocessed_example = self.get_example(i)

        example["prompts"] = unprocessed_example["prompts"]
        example["nprompts"] = unprocessed_example["nprompts"]
        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]

        if self.class_identifier is not None:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]

        return example