From 5588b93859c4380082a7e46bf5bef2119ec1907a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 26 Sep 2022 16:36:42 +0200
Subject: Init

---
 data.py | 145 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 145 insertions(+)
 create mode 100644 data.py

diff --git a/data.py b/data.py
new file mode 100644
index 0000000..0d1e96e
--- /dev/null
+++ b/data.py
@@ -0,0 +1,145 @@
+import os
+import numpy as np
+import pandas as pd
+import random
+import PIL
+import pytorch_lightning as pl
+from PIL import Image
+import torch
+from torch.utils.data import Dataset, DataLoader, random_split
+
+
+class CSVDataModule(pl.LightningDataModule):
+    def __init__(self,
+                 batch_size,
+                 data_root,
+                 tokenizer,
+                 size=512,
+                 repeats=100,
+                 interpolation="bicubic",
+                 placeholder_token="*",
+                 flip_p=0.5,
+                 center_crop=False):
+        super().__init__()
+
+        self.data_root = data_root
+        self.tokenizer = tokenizer
+        self.size = size
+        self.repeats = repeats
+        self.placeholder_token = placeholder_token
+        self.center_crop = center_crop
+        self.flip_p = flip_p
+        self.interpolation = interpolation
+        self.batch_size = batch_size
+
+    def prepare_data(self):
+        # Read the dataset manifest; rows whose skip column is "x" are excluded.
+        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
+        captions = [caption for caption in metadata['caption'].values]
+        skips = [skip for skip in metadata['skip'].values]
+        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+
+    def setup(self, stage=None):
+        # 80/20 train/validation split.
+        train_set_size = int(len(self.data_full) * 0.8)
+        valid_set_size = len(self.data_full) - train_set_size
+        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
+
+        train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats,
+                                   interpolation=self.interpolation, flip_p=self.flip_p,
+                                   placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size,
+                                 interpolation=self.interpolation, flip_p=self.flip_p,
+                                 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
+
+    def train_dataloader(self):
+        return self.train_dataloader_
+
+    def val_dataloader(self):
+        return self.val_dataloader_
+
+
+class CSVDataset(Dataset):
+    def __init__(self,
+                 data,
+                 tokenizer,
+                 size=512,
+                 repeats=1,
+                 interpolation="bicubic",
+                 flip_p=0.5,
+                 placeholder_token="*",
+                 center_crop=False):
+        self.data = data
+        self.tokenizer = tokenizer
+        self.num_images = len(self.data)
+        self._length = self.num_images * repeats
+        self.placeholder_token = placeholder_token
+        self.size = size
+        self.center_crop = center_crop
+        self.flip_p = flip_p
+        self.interpolation = {"linear": PIL.Image.BILINEAR,
+                              "bilinear": PIL.Image.BILINEAR,
+                              "bicubic": PIL.Image.BICUBIC,
+                              "lanczos": PIL.Image.LANCZOS,
+                              }[interpolation]
+
+        # Preprocessed examples are cached per (image path, flip state).
+        self.cache = {}
+
+    def __len__(self):
+        return self._length
+
+    def get_example(self, i, flipped):
+        image_path, text = self.data[i % self.num_images]
+        key = (image_path, flipped)
+
+        if key in self.cache:
+            return self.cache[key]
+
+        example = {}
+        image = Image.open(image_path)
+
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+
+        # Substitute the placeholder token into the caption template.
+        text = text.format(self.placeholder_token)
+
+        example["prompt"] = text
+        example["input_ids"] = self.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        ).input_ids[0]
+
+        # default to score-sde preprocessing
+        img = np.array(image).astype(np.uint8)
+
+        if self.center_crop:
+            crop = min(img.shape[0], img.shape[1])
+            h, w = img.shape[0], img.shape[1]
+            img = img[(h - crop) // 2:(h + crop) // 2,
+                      (w - crop) // 2:(w + crop) // 2]
+
+        image = Image.fromarray(img)
+        image = image.resize((self.size, self.size), resample=self.interpolation)
+        if flipped:
+            image = image.transpose(PIL.Image.FLIP_LEFT_RIGHT)
+        image = np.array(image).astype(np.uint8)
+        # Scale pixel values from [0, 255] to [-1.0, 1.0].
+        image = (image / 127.5 - 1.0).astype(np.float32)
+
+        example["key"] = "-".join([image_path, str(flipped)])
+        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+
+        self.cache[key] = example
+        return example
+
+    def __getitem__(self, i):
+        flipped = random.random() < self.flip_p
+        return self.get_example(i, flipped)
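For reference, prepare_data() expects a manifest at <data_root>/list.csv with image, caption, and skip columns: image paths are resolved relative to data_root, captions may contain a "{}" placeholder that get_example() fills with the placeholder token, and rows whose skip column is "x" are dropped. A minimal sketch of writing such a manifest follows; the file names and captions are invented for illustration, only the column names and the "x" convention come from the patch above.

import pandas as pd

# Hypothetical rows; adjust paths and captions to your dataset.
manifest = pd.DataFrame({
    "image": ["img/0001.jpg", "img/0002.jpg", "img/0003.jpg"],
    "caption": ["a photo of {}", "a rendering of {}", "{} on a wooden table"],
    "skip": ["", "x", ""],  # the second row will be excluded
})
manifest.to_csv("data/list.csv", index=False)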
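And a minimal driver sketch, assuming a Hugging Face CLIPTokenizer as the tokenizer argument (it provides the model_max_length attribute and input_ids the dataset relies on); the checkpoint name, batch size, and data root are assumptions, not part of this commit.

from transformers import CLIPTokenizer

from data import CSVDataModule

# Assumed tokenizer; any tokenizer with model_max_length and the same
# call signature would work.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = CSVDataModule(
    batch_size=4,
    data_root="data",  # directory containing list.csv and the images
    tokenizer=tokenizer,
    placeholder_token="*",
)
datamodule.prepare_data()
datamodule.setup()

# One collated batch from the training split.
batch = next(iter(datamodule.train_dataloader()))
print(batch["prompt"])              # captions with the placeholder substituted
print(batch["input_ids"].shape)     # (batch_size, tokenizer.model_max_length)
print(batch["pixel_values"].shape)  # (batch_size, 3, size, size), in [-1, 1]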