import math
import os
import pandas as pd
from pathlib import Path
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module that reads image/prompt records from a CSV file.

    The CSV must contain an ``image`` and a ``prompt`` column; ``nprompt``
    (negative prompt) and ``skip`` columns are optional — rows whose ``skip``
    value is ``"x"`` are dropped. Instance images are resolved relative to
    the CSV's directory; class (prior-preservation) images live in
    ``class_subdir`` next to the CSV and share the instance image's filename.
    """

    def __init__(self,
                 batch_size,
                 data_file,
                 tokenizer,
                 instance_identifier,
                 class_identifier=None,
                 class_subdir="db_cls",
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 center_crop=False,
                 valid_set_size=None,
                 generator=None,
                 collate_fn=None):
        """Store configuration and validate/prepare the data locations.

        Args:
            batch_size: batch size for both train and validation loaders.
            data_file: path to the CSV metadata file.
            tokenizer: tokenizer used by CSVDataset to encode prompts.
            instance_identifier: placeholder value substituted into prompts
                for instance examples.
            class_identifier: if given, class examples are produced as well.
            class_subdir: subdirectory (next to the CSV) holding class images.
            size: target image edge length in pixels.
            repeats: virtual dataset multiplier passed to CSVDataset.
            interpolation: one of "linear", "bilinear", "bicubic", "lanczos".
            center_crop: use a deterministic center crop instead of random.
            valid_set_size: optional upper bound on the validation set size.
            generator: optional torch.Generator for the train/valid split.
            collate_fn: optional collate function for both dataloaders.

        Raises:
            ValueError: if ``data_file`` does not point at an existing file.
        """
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        # Class images go into a sibling subdirectory which is created
        # eagerly so downstream generation code can write into it.
        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)

        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.collate_fn = collate_fn
        self.batch_size = batch_size

    def prepare_data(self):
        """Read the CSV and build (instance, class, prompt, nprompt) tuples.

        Rows whose optional ``skip`` column equals ``"x"`` are excluded.
        """
        metadata = pd.read_csv(self.data_file)
        instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
        # Class images mirror the instance file names inside class_root.
        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
        prompts = metadata['prompt'].values
        # Optional columns fall back to one empty string per row.
        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
        self.data = [(i, c, p, n)
                     for i, c, p, n, s
                     in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
                     if s != "x"]

    def setup(self, stage=None):
        """Split the data and build the train/validation datasets and loaders."""
        # Hold out 20% for validation, capped by valid_set_size (when given)
        # and never less than one sample.
        valid_set_size = int(len(self.data) * 0.2)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = len(self.data) - valid_set_size

        self.data_train, self.data_val = random_split(
            self.data, [train_set_size, valid_set_size], generator=self.generator)

        train_dataset = CSVDataset(self.data_train, self.tokenizer,
                                   instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                   size=self.size, interpolation=self.interpolation,
                                   center_crop=self.center_crop, repeats=self.repeats)
        # The validation dataset deliberately omits class_identifier, so
        # validation batches never carry prior-preservation examples.
        val_dataset = CSVDataset(self.data_val, self.tokenizer,
                                 instance_identifier=self.instance_identifier,
                                 size=self.size, interpolation=self.interpolation,
                                 center_crop=self.center_crop, repeats=self.repeats)
        # NOTE(review): drop_last=True on the validation loader means a
        # validation set smaller than batch_size yields zero batches.
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
                                          pin_memory=True, collate_fn=self.collate_fn)

    def train_dataloader(self):
        """Return the training DataLoader built in setup()."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation DataLoader built in setup()."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Map-style dataset over (instance path, class path, prompt, nprompt) rows.

    Raw PIL images and tokenized prompts are cached per instance path; the
    image transforms are applied lazily in ``__getitem__`` so random crop and
    flip vary between epochs. ``repeats`` virtually lengthens the dataset by
    cycling over the underlying rows.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 instance_identifier,
                 class_identifier=None,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 center_crop=False,
                 ):
        self.data = data
        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        # Cache of un-transformed examples keyed by instance image path.
        # NOTE(review): unbounded — every decoded PIL image stays resident
        # for the dataset's lifetime.
        self.cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        # FIX: "linear" previously mapped to InterpolationMode.NEAREST.
        # This option set is ported from PIL, where Image.LINEAR is an alias
        # for Image.BILINEAR, so "linear" must select bilinear filtering.
        self.interpolation = {"linear": transforms.InterpolationMode.BILINEAR,
                              "bilinear": transforms.InterpolationMode.BILINEAR,
                              "bicubic": transforms.InterpolationMode.BICUBIC,
                              "lanczos": transforms.InterpolationMode.LANCZOS,
                              }[interpolation]
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # Map [0, 1] tensors to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def get_example(self, i):
        """Return the cached raw example for index ``i`` (modulo the data size).

        The example holds the un-transformed PIL image(s), the raw prompt
        strings, and the tokenized prompt ids.
        """
        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]

        # NOTE(review): the cache key is the instance path only — it assumes
        # each image appears at most once in the CSV; duplicate images with
        # different prompts would collide. Confirm against the data source.
        if instance_image_path in self.cache:
            return self.cache[instance_image_path]

        example = {}

        example["prompts"] = prompt
        example["nprompts"] = nprompt

        instance_image = Image.open(instance_image_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")

        example["instance_images"] = instance_image
        # The prompt template is formatted with the instance placeholder
        # before tokenization (no padding; truncated to the model maximum).
        example["instance_prompt_ids"] = self.tokenizer(
            prompt.format(self.instance_identifier),
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_identifier is not None:
            class_image = Image.open(class_image_path)
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")

            example["class_images"] = class_image
            example["class_prompt_ids"] = self.tokenizer(
                prompt.format(self.class_identifier),
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        self.cache[instance_image_path] = example
        return example

    def __getitem__(self, i):
        """Return example ``i`` with image transforms applied to the PIL images."""
        example = {}
        unprocessed_example = self.get_example(i)

        example["prompts"] = unprocessed_example["prompts"]
        example["nprompts"] = unprocessed_example["nprompts"]
        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]

        if self.class_identifier is not None:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]

        return example