From 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 6 Oct 2022 17:15:22 +0200 Subject: Update --- data/textual_inversion/csv.py | 150 ------------------------------------------ 1 file changed, 150 deletions(-) delete mode 100644 data/textual_inversion/csv.py (limited to 'data/textual_inversion/csv.py') diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py deleted file mode 100644 index 4c5e27e..0000000 --- a/data/textual_inversion/csv.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import numpy as np -import pandas as pd -from pathlib import Path -import math -import pytorch_lightning as pl -from PIL import Image -from torch.utils.data import Dataset, DataLoader, random_split -from torchvision import transforms - - -class CSVDataModule(pl.LightningDataModule): - def __init__(self, - batch_size, - data_file, - tokenizer, - size=512, - repeats=100, - interpolation="bicubic", - placeholder_token="*", - center_crop=False, - valid_set_size=None, - generator=None): - super().__init__() - - self.data_file = Path(data_file) - - if not self.data_file.is_file(): - raise ValueError("data_file must be a file") - - self.data_root = self.data_file.parent - self.tokenizer = tokenizer - self.size = size - self.repeats = repeats - self.placeholder_token = placeholder_token - self.center_crop = center_crop - self.interpolation = interpolation - self.valid_set_size = valid_set_size - self.generator = generator - - self.batch_size = batch_size - - def prepare_data(self): - metadata = pd.read_csv(self.data_file) - image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] - prompts = metadata['prompt'].values - nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) - skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) - self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] - - def setup(self, stage=None): - valid_set_size = int(len(self.data_full) * 0.2) - if self.valid_set_size: - valid_set_size = min(valid_set_size, self.valid_set_size) - valid_set_size = max(valid_set_size, 1) - train_set_size = len(self.data_full) - valid_set_size - - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) - - train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, - placeholder_token=self.placeholder_token, center_crop=self.center_crop) - val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, - placeholder_token=self.placeholder_token, center_crop=self.center_crop) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) - - def train_dataloader(self): - return self.train_dataloader_ - - def val_dataloader(self): - return self.val_dataloader_ - - -class CSVDataset(Dataset): - def __init__(self, - data, - tokenizer, - size=512, - repeats=1, - interpolation="bicubic", - placeholder_token="*", - center_crop=False, - batch_size=1, - ): - - self.data = data - self.tokenizer = tokenizer - self.placeholder_token = placeholder_token - self.batch_size = batch_size - self.cache = {} - - self.num_instance_images = len(self.data) - self._length = self.num_instance_images * repeats - - self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, - "bilinear": transforms.InterpolationMode.BILINEAR, - "bicubic": transforms.InterpolationMode.BICUBIC, - "lanczos": transforms.InterpolationMode.LANCZOS, - }[interpolation] - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=self.interpolation), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return math.ceil(self._length / self.batch_size) * self.batch_size - - def get_example(self, i): - image_path, prompt, nprompt = self.data[i % self.num_instance_images] - - if image_path in self.cache: - return self.cache[image_path] - - example = {} - - instance_image = Image.open(image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - - prompt = prompt.format(self.placeholder_token) - - example["prompts"] = prompt - example["nprompts"] = nprompt - example["pixel_values"] = instance_image - example["input_ids"] = self.tokenizer( - prompt, - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids[0] - - self.cache[image_path] = example - return example - - def __getitem__(self, i): - example = {} - unprocessed_example = self.get_example(i) - - example["prompts"] = unprocessed_example["prompts"] - example["nprompts"] = unprocessed_example["nprompts"] - example["input_ids"] = unprocessed_example["input_ids"] - example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) - - return example -- cgit v1.2.3-54-g00ecf