import os
import random

import numpy as np
import pandas as pd
import PIL
import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module serving image/caption pairs listed in a CSV file.

    The file ``<data_root>/list.csv`` must provide the columns ``image``
    (path relative to ``data_root``), ``caption`` and ``skip``.  Rows whose
    ``skip`` value equals ``"x"`` are dropped.  The remaining rows are split
    80/20 into a training and a validation set.

    Args:
        batch_size: Batch size used by both data loaders.
        data_root: Directory containing ``list.csv`` and the images.
        tokenizer: Callable tokenizer exposing ``model_max_length``
            (presumably a Hugging Face tokenizer -- confirm with the caller).
        size: Edge length in pixels of the square output images.
        repeats: How many times the training set is virtually repeated.
        interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
            ``"lanczos"``.
        placeholder_token: Value substituted into each caption via
            ``str.format``.
        flip_p: Probability of a horizontal flip per returned sample.
        center_crop: Crop images to a centered square before resizing.
    """

    def __init__(self, batch_size, data_root, tokenizer, size=512, repeats=100,
                 interpolation="bicubic", placeholder_token="*", flip_p=0.5,
                 center_crop=False):
        super().__init__()
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.repeats = repeats
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.interpolation = interpolation
        self.batch_size = batch_size
        # Filled by prepare_data() or, when that hook did not run in this
        # process, lazily by setup().
        self.data_full = None

    def _read_metadata(self):
        """Return the ``(image_path, caption)`` pairs of all non-skipped rows."""
        metadata = pd.read_csv(f'{self.data_root}/list.csv')
        rows = zip(
            metadata['image'].values,
            metadata['caption'].values,
            metadata['skip'].values,
        )
        return [
            (os.path.join(self.data_root, f_path), caption)
            for f_path, caption, skip in rows
            if skip != "x"
        ]

    def prepare_data(self):
        """Read the CSV metadata into ``self.data_full``."""
        # Kept for callers that read ``data_full`` right after this hook.
        # setup() does not rely on it: Lightning documents that prepare_data
        # may run on a single process only, so state assigned here can be
        # missing on the other ranks.
        self.data_full = self._read_metadata()

    def setup(self, stage=None):
        """Split the data 80/20 and build the train/validation data loaders."""
        if self.data_full is None:
            self.data_full = self._read_metadata()

        train_set_size = int(len(self.data_full) * 0.8)
        valid_set_size = len(self.data_full) - train_set_size
        self.data_train, self.data_val = random_split(
            self.data_full, [train_set_size, valid_set_size]
        )

        train_dataset = CSVDataset(
            self.data_train,
            self.tokenizer,
            size=self.size,
            repeats=self.repeats,
            interpolation=self.interpolation,
            flip_p=self.flip_p,
            placeholder_token=self.placeholder_token,
            center_crop=self.center_crop,
        )
        # The validation set is not repeated (CSVDataset defaults to repeats=1).
        val_dataset = CSVDataset(
            self.data_val,
            self.tokenizer,
            size=self.size,
            interpolation=self.interpolation,
            flip_p=self.flip_p,
            placeholder_token=self.placeholder_token,
            center_crop=self.center_crop,
        )

        self.train_dataloader_ = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True
        )
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        """Return the training data loader built by :meth:`setup`."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation data loader built by :meth:`setup`."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset of tokenized captions and normalized, optionally flipped images.

    Args:
        data: Indexable, sized collection of ``(image_path, caption)`` pairs.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        size: Edge length in pixels of the square output images.
        repeats: Virtual repetition factor; ``len(dataset)`` is
            ``len(data) * repeats``.
        interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
            ``"lanczos"``.
        flip_p: Probability of a horizontal flip per returned sample.
        placeholder_token: Value substituted into each caption via
            ``str.format``.
        center_crop: Crop images to a centered square before resizing.

    Raises:
        KeyError: If ``interpolation`` is not one of the supported names.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 flip_p=0.5,
                 placeholder_token="*",
                 center_crop=False,
                 ):
        self.data = data
        self.tokenizer = tokenizer

        self.num_images = len(self.data)
        self._length = self.num_images * repeats

        self.placeholder_token = placeholder_token

        self.size = size
        self.center_crop = center_crop
        self.interpolation = {
            # Image.LINEAR was an alias of Image.BILINEAR and was removed in
            # Pillow 10; referencing it here broke construction for every
            # interpolation mode because the dict is evaluated eagerly.
            "linear": Image.BILINEAR,
            "bilinear": Image.BILINEAR,
            "bicubic": Image.BICUBIC,
            "lanczos": Image.LANCZOS,
        }[interpolation]
        self.flip_p = flip_p
        # Retained for code that accesses the transform directly; the dataset
        # itself flips the tensor in get_example().
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        # Row index -> unflipped base example (prompt, input_ids, pixel_values).
        self.cache = {}

    def __len__(self):
        return self._length

    def _load_base_example(self, image_path, text):
        """Tokenize ``text`` and load ``image_path`` as an unflipped tensor."""
        prompt = text.format(self.placeholder_token)
        input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed right away.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")

        # default to score-sde preprocessing
        if self.center_crop:
            img = np.array(image).astype(np.uint8)
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]
            image = Image.fromarray(img)

        image = image.resize((self.size, self.size), resample=self.interpolation)

        # Scale uint8 [0, 255] to float32 [-1, 1] and move channels first.
        pixels = np.array(image).astype(np.uint8)
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)

        return {
            "prompt": prompt,
            "input_ids": input_ids,
            "pixel_values": torch.from_numpy(pixels).permute(2, 0, 1),
        }

    def get_example(self, i, flipped):
        """Return the example for index ``i``, mirrored when ``flipped`` is true.

        Args:
            i: Sample index; wrapped modulo the number of distinct rows.
            flipped: Whether to horizontally flip the image.

        Returns:
            A new dict with the keys ``prompt``, ``input_ids``, ``key`` and
            ``pixel_values`` (float32 tensor, channels first, range [-1, 1]).
        """
        row = i % self.num_images
        image_path, text = self.data[row]

        # Cache per row rather than per path: two rows may share an image but
        # carry different captions.  Only the unflipped base is cached so the
        # flip can still vary between accesses.
        base = self.cache.get(row)
        if base is None:
            base = self._load_base_example(image_path, text)
            self.cache[row] = base

        pixel_values = base["pixel_values"]
        if flipped:
            # Last dimension is the image width (C, H, W).
            pixel_values = torch.flip(pixel_values, dims=[-1])

        return {
            "prompt": base["prompt"],
            "input_ids": base["input_ids"],
            "key": "-".join([image_path, "-", str(flipped)]),
            "pixel_values": pixel_values,
        }

    def __getitem__(self, i):
        # Draw the flip decision once so that the "key" entry and the pixel
        # data always agree, and so that flip_p is honored.
        flipped = random.random() < self.flip_p
        return self.get_example(i, flipped)