# Reconstructed from a git patch (commit 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693,
# author Volpeon, Thu, 6 Oct 2022 17:15:22 +0200, subject "Update") that
# deleted data/dreambooth/csv.py (181 lines). The leading "-" diff markers
# were dropped and the line structure restored so the module reads as Python.

import math
import os
import pandas as pd
from pathlib import Path
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module that builds train/val loaders from a CSV file.

    The CSV is read in ``prepare_data`` and must contain an ``image`` and a
    ``prompt`` column; ``nprompt`` and ``skip`` columns are optional. Image
    paths are resolved relative to the directory containing the CSV file, and
    class images are looked up by file name inside ``class_subdir``.
    """

    def __init__(self,
                 batch_size,
                 data_file,
                 tokenizer,
                 instance_identifier,
                 class_identifier=None,
                 class_subdir="db_cls",
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 center_crop=False,
                 valid_set_size=None,
                 generator=None,
                 collate_fn=None):
        """Store the configuration and create the class-image directory.

        Args:
            batch_size: Batch size used for both data loaders.
            data_file: Path to the CSV metadata file.
            tokenizer: Tokenizer handed to the datasets.
            instance_identifier: Value substituted into each prompt template
                for instance images.
            class_identifier: Value substituted into each prompt template for
                class images; ``None`` disables class images.
            class_subdir: Directory (relative to the CSV's parent) that holds
                the class images. It is created if missing.
            size: Target image size passed to the datasets.
            repeats: Number of times the training/validation data is repeated.
            interpolation: Name of the resize interpolation mode.
            center_crop: Use a center crop instead of a random crop.
            valid_set_size: Optional upper bound on the validation set size.
            generator: Random generator forwarded to ``random_split``.
            collate_fn: Collate function forwarded to both data loaders.

        Raises:
            ValueError: If ``data_file`` is not an existing file.
        """
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root.joinpath(class_subdir)
        self.class_root.mkdir(parents=True, exist_ok=True)

        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.size = size
        self.repeats = repeats
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.collate_fn = collate_fn
        self.batch_size = batch_size

    def prepare_data(self):
        """Read the CSV and populate ``self.data``.

        Each entry of ``self.data`` is a tuple ``(instance_image_path,
        class_image_path, prompt, nprompt)``. Rows whose ``skip`` column
        equals ``"x"`` are left out.
        """
        metadata = pd.read_csv(self.data_file)
        image_names = metadata['image'].values
        instance_image_paths = [self.data_root.joinpath(f) for f in image_names]
        # Class images share the instance image's file name but live in class_root.
        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in image_names]
        prompts = metadata['prompt'].values
        # Optional columns default to an empty string per row.
        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
        self.data = [(i, c, p, n)
                     for i, c, p, n, s
                     in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
                     if s != "x"]

    def setup(self, stage=None):
        """Split ``self.data`` and build the train and validation loaders.

        The validation set is 20% of the data, capped at ``valid_set_size``
        when that is set, and never smaller than one item. The validation
        dataset is created without a class identifier, so it yields instance
        images only.
        """
        valid_set_size = int(len(self.data) * 0.2)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = len(self.data) - valid_set_size

        self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)

        train_dataset = CSVDataset(self.data_train, self.tokenizer,
                                   instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                   size=self.size, interpolation=self.interpolation,
                                   center_crop=self.center_crop, repeats=self.repeats)
        val_dataset = CSVDataset(self.data_val, self.tokenizer,
                                 instance_identifier=self.instance_identifier,
                                 size=self.size, interpolation=self.interpolation,
                                 center_crop=self.center_crop, repeats=self.repeats)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
                                          pin_memory=True, collate_fn=self.collate_fn)

    def train_dataloader(self):
        """Return the training loader built by ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation loader built by ``setup``."""
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset over ``(instance_path, class_path, prompt, nprompt)`` tuples.

    Images are opened and prompts tokenized on first access and then cached
    per instance image path; the image transforms (resize, crop, random flip,
    normalization) are re-applied on every ``__getitem__`` call.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 instance_identifier,
                 class_identifier=None,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 center_crop=False,
                 ):
        """Store the data and build the image transform pipeline.

        Args:
            data: Indexable collection of ``(instance_image_path,
                class_image_path, prompt, nprompt)`` tuples.
            tokenizer: Callable tokenizer exposing ``model_max_length`` and
                returning an object with ``input_ids``.
            instance_identifier: Value formatted into the prompt template for
                instance images.
            class_identifier: Value formatted into the prompt template for
                class images; ``None`` disables class images.
            size: Target size for resize and crop.
            repeats: How many times the data is repeated per epoch.
            interpolation: One of ``"nearest"``, ``"linear"``, ``"bilinear"``,
                ``"bicubic"`` or ``"lanczos"``.
            center_crop: Use a center crop instead of a random crop.

        Raises:
            KeyError: If ``interpolation`` is not a known mode name.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.instance_identifier = instance_identifier
        self.class_identifier = class_identifier
        self.cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        # "linear" previously resolved to NEAREST, silently degrading the
        # resize quality. Pillow treats LINEAR as an alias of BILINEAR, so
        # map it accordingly; "nearest" is offered explicitly instead.
        self.interpolation = {"nearest": transforms.InterpolationMode.NEAREST,
                              "linear": transforms.InterpolationMode.BILINEAR,
                              "bilinear": transforms.InterpolationMode.BILINEAR,
                              "bicubic": transforms.InterpolationMode.BICUBIC,
                              "lanczos": transforms.InterpolationMode.LANCZOS,
                              }[interpolation]
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the number of items including repeats."""
        return self._length

    def get_example(self, i):
        """Return the untransformed example for index ``i``, using the cache.

        The index wraps around the underlying data, so repeated epochs hit
        the same cached entries.
        """
        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]

        if instance_image_path in self.cache:
            return self.cache[instance_image_path]

        example = {}

        example["prompts"] = prompt
        example["nprompts"] = nprompt

        instance_image = Image.open(instance_image_path)
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")

        example["instance_images"] = instance_image
        example["instance_prompt_ids"] = self.tokenizer(
            prompt.format(self.instance_identifier),
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_identifier is not None:
            class_image = Image.open(class_image_path)
            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")

            example["class_images"] = class_image
            example["class_prompt_ids"] = self.tokenizer(
                prompt.format(self.class_identifier),
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        self.cache[instance_image_path] = example
        return example

    def __getitem__(self, i):
        """Return a freshly transformed example for index ``i``.

        A new dict is built on every call so the cached, untransformed
        example is never mutated.
        """
        example = {}
        unprocessed_example = self.get_example(i)

        example["prompts"] = unprocessed_example["prompts"]
        example["nprompts"] = unprocessed_example["nprompts"]
        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]

        if self.class_identifier is not None:
            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]

        return example