From 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 6 Oct 2022 17:15:22 +0200 Subject: Update --- data/csv.py | 181 ++++++++++++++++++++++++++++++++++++++++++ data/dreambooth/csv.py | 181 ------------------------------------------ data/textual_inversion/csv.py | 150 ---------------------------------- 3 files changed, 181 insertions(+), 331 deletions(-) create mode 100644 data/csv.py delete mode 100644 data/dreambooth/csv.py delete mode 100644 data/textual_inversion/csv.py (limited to 'data') diff --git a/data/csv.py b/data/csv.py new file mode 100644 index 0000000..abd329d --- /dev/null +++ b/data/csv.py @@ -0,0 +1,181 @@ +import math +import os +import pandas as pd +from pathlib import Path +import pytorch_lightning as pl +from PIL import Image +from torch.utils.data import Dataset, DataLoader, random_split +from torchvision import transforms + + +class CSVDataModule(pl.LightningDataModule): + def __init__(self, + batch_size, + data_file, + tokenizer, + instance_identifier, + class_identifier=None, + class_subdir="db_cls", + size=512, + repeats=100, + interpolation="bicubic", + center_crop=False, + valid_set_size=None, + generator=None, + collate_fn=None): + super().__init__() + + self.data_file = Path(data_file) + + if not self.data_file.is_file(): + raise ValueError("data_file must be a file") + + self.data_root = self.data_file.parent + self.class_root = self.data_root.joinpath(class_subdir) + self.class_root.mkdir(parents=True, exist_ok=True) + + self.tokenizer = tokenizer + self.instance_identifier = instance_identifier + self.class_identifier = class_identifier + self.size = size + self.repeats = repeats + self.center_crop = center_crop + self.interpolation = interpolation + self.valid_set_size = valid_set_size + self.generator = generator + self.collate_fn = collate_fn + self.batch_size = batch_size + + def prepare_data(self): + metadata = pd.read_csv(self.data_file) + instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] + class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] + prompts = metadata['prompt'].values + nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) + skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) + self.data = [(i, c, p, n) + for i, c, p, n, s + in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) + if s != "x"] + + def setup(self, stage=None): + valid_set_size = int(len(self.data) * 0.2) + if self.valid_set_size: + valid_set_size = min(valid_set_size, self.valid_set_size) + valid_set_size = max(valid_set_size, 1) + train_set_size = len(self.data) - valid_set_size + + self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) + + train_dataset = CSVDataset(self.data_train, self.tokenizer, + instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats) + val_dataset = CSVDataset(self.data_val, self.tokenizer, + instance_identifier=self.instance_identifier, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats) + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, + shuffle=True, pin_memory=True, collate_fn=self.collate_fn) + self.val_dataloader_ = 
DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, + pin_memory=True, collate_fn=self.collate_fn) + + def train_dataloader(self): + return self.train_dataloader_ + + def val_dataloader(self): + return self.val_dataloader_ + + +class CSVDataset(Dataset): + def __init__(self, + data, + tokenizer, + instance_identifier, + class_identifier=None, + size=512, + repeats=1, + interpolation="bicubic", + center_crop=False, + ): + + self.data = data + self.tokenizer = tokenizer + self.instance_identifier = instance_identifier + self.class_identifier = class_identifier + self.cache = {} + + self.num_instance_images = len(self.data) + self._length = self.num_instance_images * repeats + + self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, + "bilinear": transforms.InterpolationMode.BILINEAR, + "bicubic": transforms.InterpolationMode.BICUBIC, + "lanczos": transforms.InterpolationMode.LANCZOS, + }[interpolation] + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=self.interpolation), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def get_example(self, i): + instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] + + if instance_image_path in self.cache: + return self.cache[instance_image_path] + + example = {} + + example["prompts"] = prompt + example["nprompts"] = nprompt + + instance_image = Image.open(instance_image_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + + example["instance_images"] = instance_image + example["instance_prompt_ids"] = self.tokenizer( + prompt.format(self.instance_identifier), + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_identifier is not None: + class_image = Image.open(class_image_path) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + + example["class_images"] = class_image + example["class_prompt_ids"] = self.tokenizer( + prompt.format(self.class_identifier), + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + self.cache[instance_image_path] = example + return example + + def __getitem__(self, i): + example = {} + unprocessed_example = self.get_example(i) + + example["prompts"] = unprocessed_example["prompts"] + example["nprompts"] = unprocessed_example["nprompts"] + example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) + example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] + + if self.class_identifier is not None: + example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) + example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] + + return example diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py deleted file mode 100644 index abd329d..0000000 --- a/data/dreambooth/csv.py +++ /dev/null @@ -1,181 +0,0 @@ -import math -import os -import pandas as pd -from pathlib import Path -import pytorch_lightning as pl -from PIL import Image -from torch.utils.data import Dataset, DataLoader, random_split -from torchvision import transforms - - -class CSVDataModule(pl.LightningDataModule): - def __init__(self, - batch_size, - data_file, - tokenizer, - instance_identifier, - 
class_identifier=None, - class_subdir="db_cls", - size=512, - repeats=100, - interpolation="bicubic", - center_crop=False, - valid_set_size=None, - generator=None, - collate_fn=None): - super().__init__() - - self.data_file = Path(data_file) - - if not self.data_file.is_file(): - raise ValueError("data_file must be a file") - - self.data_root = self.data_file.parent - self.class_root = self.data_root.joinpath(class_subdir) - self.class_root.mkdir(parents=True, exist_ok=True) - - self.tokenizer = tokenizer - self.instance_identifier = instance_identifier - self.class_identifier = class_identifier - self.size = size - self.repeats = repeats - self.center_crop = center_crop - self.interpolation = interpolation - self.valid_set_size = valid_set_size - self.generator = generator - self.collate_fn = collate_fn - self.batch_size = batch_size - - def prepare_data(self): - metadata = pd.read_csv(self.data_file) - instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] - class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] - prompts = metadata['prompt'].values - nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) - skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) - self.data = [(i, c, p, n) - for i, c, p, n, s - in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) - if s != "x"] - - def setup(self, stage=None): - valid_set_size = int(len(self.data) * 0.2) - if self.valid_set_size: - valid_set_size = min(valid_set_size, self.valid_set_size) - valid_set_size = max(valid_set_size, 1) - train_set_size = len(self.data) - valid_set_size - - self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) - - train_dataset = CSVDataset(self.data_train, self.tokenizer, - instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats) - val_dataset = CSVDataset(self.data_val, self.tokenizer, - instance_identifier=self.instance_identifier, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, - shuffle=True, pin_memory=True, collate_fn=self.collate_fn) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, - pin_memory=True, collate_fn=self.collate_fn) - - def train_dataloader(self): - return self.train_dataloader_ - - def val_dataloader(self): - return self.val_dataloader_ - - -class CSVDataset(Dataset): - def __init__(self, - data, - tokenizer, - instance_identifier, - class_identifier=None, - size=512, - repeats=1, - interpolation="bicubic", - center_crop=False, - ): - - self.data = data - self.tokenizer = tokenizer - self.instance_identifier = instance_identifier - self.class_identifier = class_identifier - self.cache = {} - - self.num_instance_images = len(self.data) - self._length = self.num_instance_images * repeats - - self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, - "bilinear": transforms.InterpolationMode.BILINEAR, - "bicubic": transforms.InterpolationMode.BICUBIC, - "lanczos": transforms.InterpolationMode.LANCZOS, - }[interpolation] - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=self.interpolation), - 
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return self._length - - def get_example(self, i): - instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] - - if instance_image_path in self.cache: - return self.cache[instance_image_path] - - example = {} - - example["prompts"] = prompt - example["nprompts"] = nprompt - - instance_image = Image.open(instance_image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - - example["instance_images"] = instance_image - example["instance_prompt_ids"] = self.tokenizer( - prompt.format(self.instance_identifier), - padding="do_not_pad", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids - - if self.class_identifier is not None: - class_image = Image.open(class_image_path) - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - - example["class_images"] = class_image - example["class_prompt_ids"] = self.tokenizer( - prompt.format(self.class_identifier), - padding="do_not_pad", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids - - self.cache[instance_image_path] = example - return example - - def __getitem__(self, i): - example = {} - unprocessed_example = self.get_example(i) - - example["prompts"] = unprocessed_example["prompts"] - example["nprompts"] = unprocessed_example["nprompts"] - example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) - example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - - if self.class_identifier is not None: - example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) - example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] - - return example diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py deleted file mode 100644 index 4c5e27e..0000000 --- a/data/textual_inversion/csv.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import numpy as np -import pandas as pd -from pathlib import Path -import math -import pytorch_lightning as pl -from PIL import Image -from torch.utils.data import Dataset, DataLoader, random_split -from torchvision import transforms - - -class CSVDataModule(pl.LightningDataModule): - def __init__(self, - batch_size, - data_file, - tokenizer, - size=512, - repeats=100, - interpolation="bicubic", - placeholder_token="*", - center_crop=False, - valid_set_size=None, - generator=None): - super().__init__() - - self.data_file = Path(data_file) - - if not self.data_file.is_file(): - raise ValueError("data_file must be a file") - - self.data_root = self.data_file.parent - self.tokenizer = tokenizer - self.size = size - self.repeats = repeats - self.placeholder_token = placeholder_token - self.center_crop = center_crop - self.interpolation = interpolation - self.valid_set_size = valid_set_size - self.generator = generator - - self.batch_size = batch_size - - def prepare_data(self): - metadata = pd.read_csv(self.data_file) - image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] - prompts = metadata['prompt'].values - nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) - skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) - self.data_full = [(i, p, n) for i, p, n, s in 
zip(image_paths, prompts, nprompts, skips) if s != "x"] - - def setup(self, stage=None): - valid_set_size = int(len(self.data_full) * 0.2) - if self.valid_set_size: - valid_set_size = min(valid_set_size, self.valid_set_size) - valid_set_size = max(valid_set_size, 1) - train_set_size = len(self.data_full) - valid_set_size - - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) - - train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, - placeholder_token=self.placeholder_token, center_crop=self.center_crop) - val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, - placeholder_token=self.placeholder_token, center_crop=self.center_crop) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) - - def train_dataloader(self): - return self.train_dataloader_ - - def val_dataloader(self): - return self.val_dataloader_ - - -class CSVDataset(Dataset): - def __init__(self, - data, - tokenizer, - size=512, - repeats=1, - interpolation="bicubic", - placeholder_token="*", - center_crop=False, - batch_size=1, - ): - - self.data = data - self.tokenizer = tokenizer - self.placeholder_token = placeholder_token - self.batch_size = batch_size - self.cache = {} - - self.num_instance_images = len(self.data) - self._length = self.num_instance_images * repeats - - self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, - "bilinear": transforms.InterpolationMode.BILINEAR, - "bicubic": transforms.InterpolationMode.BICUBIC, - "lanczos": transforms.InterpolationMode.LANCZOS, - }[interpolation] - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=self.interpolation), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return math.ceil(self._length / self.batch_size) * self.batch_size - - def get_example(self, i): - image_path, prompt, nprompt = self.data[i % self.num_instance_images] - - if image_path in self.cache: - return self.cache[image_path] - - example = {} - - instance_image = Image.open(image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - - prompt = prompt.format(self.placeholder_token) - - example["prompts"] = prompt - example["nprompts"] = nprompt - example["pixel_values"] = instance_image - example["input_ids"] = self.tokenizer( - prompt, - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids[0] - - self.cache[image_path] = example - return example - - def __getitem__(self, i): - example = {} - unprocessed_example = self.get_example(i) - - example["prompts"] = unprocessed_example["prompts"] - example["nprompts"] = unprocessed_example["nprompts"] - example["input_ids"] = unprocessed_example["input_ids"] - example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) - - return example -- cgit v1.2.3-54-g00ecf
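
For readers of this patch, a minimal usage sketch of the consolidated CSVDataModule, which now serves both the DreamBooth and textual-inversion training paths (class_identifier=None selects the latter). The tokenizer checkpoint, CSV path, and identifier strings below are illustrative assumptions, not values taken from this repository; collate_fn refers to the sketch after this one.

    from transformers import CLIPTokenizer

    from data.csv import CSVDataModule

    # The metadata CSV must provide "image" and "prompt" columns;
    # "nprompt" and "skip" are optional, and rows whose skip column
    # equals "x" are dropped in prepare_data(). Prompts may contain
    # "{}", which is filled with the instance or class identifier
    # when the prompt is tokenized.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    datamodule = CSVDataModule(
        batch_size=1,
        data_file="data/images.csv",   # hypothetical path
        tokenizer=tokenizer,
        instance_identifier="sks",     # hypothetical instance token
        class_identifier="person",     # pass None for textual-inversion-style use
        size=512,
        repeats=100,
        collate_fn=collate_fn,         # see the sketch below
    )
    datamodule.prepare_data()
    datamodule.setup()

    for batch in datamodule.train_dataloader():
        ...  # hand the batch to the training step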
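
Because CSVDataset tokenizes prompts with padding="do_not_pad", the default PyTorch collation would fail on the variable-length input_ids, which is why the module accepts a caller-supplied collate_fn. A minimal sketch under the same assumptions (the returned keys are one reasonable choice, not dictated by the patch; a full DreamBooth-style collate would additionally fold class_images and class_prompt_ids into the batch for prior preservation):

    import torch

    def collate_fn(examples):
        # "tokenizer" is the CLIPTokenizer from the previous sketch.
        # Stack the already-transformed image tensors and pad the
        # variable-length token id lists in a single batched call.
        pixel_values = torch.stack([e["instance_images"] for e in examples])
        padded = tokenizer.pad(
            {"input_ids": [e["instance_prompt_ids"] for e in examples]},
            padding=True,
            return_tensors="pt",
        )
        return {
            "prompts": [e["prompts"] for e in examples],
            "nprompts": [e["nprompts"] for e in examples],
            "pixel_values": pixel_values,
            "input_ids": padded.input_ids,
        }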