From 5d2abb1749b5d2f2f22ad603b5c2bf9182864520 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 10:03:12 +0200
Subject: More cleanup

---
 data.py | 145 ----------------------------------------------------------------
 1 file changed, 145 deletions(-)
 delete mode 100644 data.py

diff --git a/data.py b/data.py
deleted file mode 100644
index 0d1e96e..0000000
--- a/data.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os
-import numpy as np
-import pandas as pd
-import random
-import PIL
-import pytorch_lightning as pl
-from PIL import Image
-import torch
-from torch.utils.data import Dataset, DataLoader, random_split
-from torchvision import transforms
-
-
-class CSVDataModule(pl.LightningDataModule):
-    def __init__(self,
-                 batch_size,
-                 data_root,
-                 tokenizer,
-                 size=512,
-                 repeats=100,
-                 interpolation="bicubic",
-                 placeholder_token="*",
-                 flip_p=0.5,
-                 center_crop=False):
-        super().__init__()
-
-        self.data_root = data_root
-        self.tokenizer = tokenizer
-        self.size = size
-        self.repeats = repeats
-        self.placeholder_token = placeholder_token
-        self.center_crop = center_crop
-        self.flip_p = flip_p
-        self.interpolation = interpolation
-
-        self.batch_size = batch_size
-
-    def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
-        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
-
-    def setup(self, stage=None):
-        train_set_size = int(len(self.data_full) * 0.8)
-        valid_set_size = len(self.data_full) - train_set_size
-        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
-
-        train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
-                                   flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
-                                 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
-
-    def train_dataloader(self):
-        return self.train_dataloader_
-
-    def val_dataloader(self):
-        return self.val_dataloader_
-
-
-class CSVDataset(Dataset):
-    def __init__(self,
-                 data,
-                 tokenizer,
-                 size=512,
-                 repeats=1,
-                 interpolation="bicubic",
-                 flip_p=0.5,
-                 placeholder_token="*",
-                 center_crop=False,
-                 ):
-
-        self.data = data
-        self.tokenizer = tokenizer
-
-        self.num_images = len(self.data)
-        self._length = self.num_images * repeats
-
-        self.placeholder_token = placeholder_token
-
-        self.size = size
-        self.center_crop = center_crop
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
-                              }[interpolation]
-        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
-
-        self.cache = {}
-
-    def __len__(self):
-        return self._length
-
-    def get_example(self, i, flipped):
-        image_path, text = self.data[i % self.num_images]
-
-        if image_path in self.cache:
-            return self.cache[image_path]
-
-        example = {}
-        image = Image.open(image_path)
-
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-
-        text = text.format(self.placeholder_token)
-
-        example["prompt"] = text
-        example["input_ids"] = self.tokenizer(
-            text,
-            padding="max_length",
-            truncation=True,
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids[0]
-
-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
-
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = img.shape[0], img.shape[1]
-            img = img[(h - crop) // 2:(h + crop) // 2,
-                      (w - crop) // 2:(w + crop) // 2]
-
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size),
-                             resample=self.interpolation)
-        image = self.flip(image)
-        image = np.array(image).astype(np.uint8)
-        image = (image / 127.5 - 1.0).astype(np.float32)
-
-        example["key"] = "-".join([image_path, "-", str(flipped)])
-        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
-
-        self.cache[image_path] = example
-        return example
-
-    def __getitem__(self, i):
-        flipped = random.choice([False, True])
-        example = self.get_example(i, flipped)
-        return example
-- 
cgit v1.2.3-54-g00ecf
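
For reference, the removed CSVDataModule read list.csv from data_root (columns: image, caption, skip), dropped rows whose skip column is "x", split the remainder 80/20 into train and validation sets, and served both through CSVDataset. The sketch below shows how the module would plausibly have been driven before this commit; the CLIPTokenizer, its checkpoint name, and the data paths are illustrative assumptions, not part of this patch.

# Usage sketch for the deleted classes (assumed setup; the tokenizer
# checkpoint and paths are illustrative, not taken from this patch).
from transformers import CLIPTokenizer

from data import CSVDataModule  # the file removed by this commit

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = CSVDataModule(
    batch_size=4,
    data_root="data/my-concept",     # must contain list.csv with image/caption/skip columns
    tokenizer=tokenizer,
    size=512,
    repeats=100,                     # each training image is visited 100 times per epoch
    placeholder_token="<my-token>",  # substituted into "{}" slots in each caption
)
datamodule.prepare_data()
datamodule.setup()

batch = next(iter(datamodule.train_dataloader()))
print(batch["pixel_values"].shape)  # e.g. torch.Size([4, 3, 512, 512])
print(batch["input_ids"].shape)     # e.g. torch.Size([4, 77]); CLIP pads to its 77-token limit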