From 5d2abb1749b5d2f2f22ad603b5c2bf9182864520 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 10:03:12 +0200 Subject: More cleanup --- data/textual_inversion/csv.py | 134 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 data/textual_inversion/csv.py (limited to 'data/textual_inversion') diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py new file mode 100644 index 0000000..38ffb6f --- /dev/null +++ b/data/textual_inversion/csv.py @@ -0,0 +1,134 @@ +import os +import numpy as np +import pandas as pd +import random +import PIL +import pytorch_lightning as pl +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader, random_split +from torchvision import transforms + + +class CSVDataModule(pl.LightningDataModule): + def __init__(self, + batch_size, + data_root, + tokenizer, + size=512, + repeats=100, + interpolation="bicubic", + placeholder_token="*", + flip_p=0.5, + center_crop=False): + super().__init__() + + self.data_root = data_root + self.tokenizer = tokenizer + self.size = size + self.repeats = repeats + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + self.interpolation = interpolation + + self.batch_size = batch_size + + def prepare_data(self): + metadata = pd.read_csv(f'{self.data_root}/list.csv') + image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] + captions = [caption for caption in metadata['caption'].values] + skips = [skip for skip in metadata['skip'].values] + self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] + + def setup(self, stage=None): + train_set_size = int(len(self.data_full) * 0.8) + valid_set_size = len(self.data_full) - train_set_size + self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) + + train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, + flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) + val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, + flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) + self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) + + def train_dataloader(self): + return self.train_dataloader_ + + def val_dataloader(self): + return self.val_dataloader_ + + +class CSVDataset(Dataset): + def __init__(self, + data, + tokenizer, + size=512, + repeats=1, + interpolation="bicubic", + flip_p=0.5, + placeholder_token="*", + center_crop=False, + ): + + self.data = data + self.tokenizer = tokenizer + + self.num_images = len(self.data) + self._length = self.num_images * repeats + + self.placeholder_token = placeholder_token + + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.cache = {} + + def __len__(self): + return self._length + + def get_example(self, i, flipped): + image_path, text = self.data[i % self.num_images] + + if image_path in self.cache: + return self.cache[image_path] + + example = {} + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = self.image_transforms(image) + + text = text.format(self.placeholder_token) + + example["prompt"] = text + example["input_ids"] = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + example["key"] = "-".join([image_path, "-", str(flipped)]) + example["pixel_values"] = image + + self.cache[image_path] = example + return example + + def __getitem__(self, i): + flipped = random.choice([False, True]) + example = self.get_example(i, flipped) + return example -- cgit v1.2.3-70-g09d2