From 0f493e1ac8406de061861ed390f283e821180e79 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 11:26:31 +0200 Subject: Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion --- data/dreambooth/csv.py | 1 - data/textual_inversion/csv.py | 98 ++++++++++++++++++++----------------------- 2 files changed, 46 insertions(+), 53 deletions(-) (limited to 'data') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 99bcf12..1676d35 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -2,7 +2,6 @@ import math import os import pandas as pd from pathlib import Path -import PIL import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 0d1e96e..f306c7a 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py @@ -1,11 +1,10 @@ import os import numpy as np import pandas as pd -import random -import PIL +from pathlib import Path +import math import pytorch_lightning as pl from PIL import Image -import torch from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms @@ -13,29 +12,32 @@ from torchvision import transforms class CSVDataModule(pl.LightningDataModule): def __init__(self, batch_size, - data_root, + data_file, tokenizer, size=512, repeats=100, interpolation="bicubic", placeholder_token="*", - flip_p=0.5, center_crop=False): super().__init__() - self.data_root = data_root + self.data_file = Path(data_file) + + if not self.data_file.is_file(): + raise ValueError("data_file must be a file") + + self.data_root = self.data_file.parent self.tokenizer = tokenizer self.size = size self.repeats = repeats self.placeholder_token = placeholder_token self.center_crop = center_crop - self.flip_p = flip_p self.interpolation = interpolation self.batch_size = batch_size def prepare_data(self): - metadata = pd.read_csv(f'{self.data_root}/list.csv') + metadata = pd.read_csv(self.data_file) image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] captions = [caption for caption in metadata['caption'].values] skips = [skip for skip in metadata['skip'].values] @@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule): self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, - flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) + placeholder_token=self.placeholder_token, center_crop=self.center_crop) val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, - flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) + placeholder_token=self.placeholder_token, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) @@ -67,48 +69,54 @@ class CSVDataset(Dataset): size=512, repeats=1, interpolation="bicubic", - flip_p=0.5, placeholder_token="*", center_crop=False, + batch_size=1, ): self.data = data self.tokenizer = tokenizer - - self.num_images = len(self.data) - self._length = self.num_images * repeats - self.placeholder_token = placeholder_token + self.batch_size = batch_size + self.cache = {} - self.size = size - self.center_crop = center_crop - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - self.flip = transforms.RandomHorizontalFlip(p=flip_p) + self.num_instance_images = len(self.data) + self._length = self.num_instance_images * repeats - self.cache = {} + self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, + "bilinear": transforms.InterpolationMode.BILINEAR, + "bicubic": transforms.InterpolationMode.BICUBIC, + "lanczos": transforms.InterpolationMode.LANCZOS, + }[interpolation] + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=self.interpolation), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): - return self._length + return math.ceil(self._length / self.batch_size) * self.batch_size - def get_example(self, i, flipped): - image_path, text = self.data[i % self.num_images] + def get_example(self, i): + image_path, text = self.data[i % self.num_instance_images] if image_path in self.cache: return self.cache[image_path] example = {} - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") + instance_image = Image.open(image_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") text = text.format(self.placeholder_token) - example["prompt"] = text + example["prompts"] = text + example["pixel_values"] = instance_image example["input_ids"] = self.tokenizer( text, padding="max_length", @@ -117,29 +125,15 @@ class CSVDataset(Dataset): return_tensors="pt", ).input_ids[0] - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] - - image = Image.fromarray(img) - image = image.resize((self.size, self.size), - resample=self.interpolation) - image = self.flip(image) - image = np.array(image).astype(np.uint8) - image = (image / 127.5 - 1.0).astype(np.float32) - - example["key"] = "-".join([image_path, "-", str(flipped)]) - example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) - self.cache[image_path] = example return example def __getitem__(self, i): - flipped = random.choice([False, True]) - example = self.get_example(i, flipped) + example = {} + unprocessed_example = self.get_example(i) + + example["prompts"] = unprocessed_example["prompts"] + example["input_ids"] = unprocessed_example["input_ids"] + example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) + return example -- cgit v1.2.3-70-g09d2