Diffstat (limited to 'data/textual_inversion')
-rw-r--r--   data/textual_inversion/csv.py   98
1 file changed, 46 insertions(+), 52 deletions(-)
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 0d1e96e..f306c7a 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -1,11 +1,10 @@
 import os
 import numpy as np
 import pandas as pd
-import random
-import PIL
+from pathlib import Path
+import math
 import pytorch_lightning as pl
 from PIL import Image
-import torch
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
 
@@ -13,29 +12,32 @@ from torchvision import transforms
 class CSVDataModule(pl.LightningDataModule):
     def __init__(self,
                  batch_size,
-                 data_root,
+                 data_file,
                  tokenizer,
                  size=512,
                  repeats=100,
                  interpolation="bicubic",
                  placeholder_token="*",
-                 flip_p=0.5,
                  center_crop=False):
         super().__init__()
 
-        self.data_root = data_root
+        self.data_file = Path(data_file)
+
+        if not self.data_file.is_file():
+            raise ValueError("data_file must be a file")
+
+        self.data_root = self.data_file.parent
         self.tokenizer = tokenizer
         self.size = size
         self.repeats = repeats
         self.placeholder_token = placeholder_token
         self.center_crop = center_crop
-        self.flip_p = flip_p
         self.interpolation = interpolation
 
         self.batch_size = batch_size
 
     def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
         captions = [caption for caption in metadata['caption'].values]
         skips = [skip for skip in metadata['skip'].values]
@@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
-                                   flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+                                   placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
-                                 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+                                 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
 
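Note (not part of the commit): a minimal usage sketch for the reworked data module. The CSV path, the tokenizer checkpoint, and the trailing Trainer comment are assumptions for illustration; the diff itself only requires that data_file point at an existing metadata CSV with image, caption and skip columns, whose image paths are resolved relative to that file's directory.

    # Hypothetical driver script; CLIPTokenizer and the import path are assumptions,
    # neither appears in this diff.
    from transformers import CLIPTokenizer
    from data.textual_inversion.csv import CSVDataModule  # assuming the repo root is on sys.path

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    datamodule = CSVDataModule(
        batch_size=4,
        data_file="training/list.csv",   # must be an existing file; its parent dir becomes data_root
        tokenizer=tokenizer,
        size=512,
        repeats=100,
        interpolation="bicubic",
        placeholder_token="*",
    )

    datamodule.prepare_data()   # reads the metadata CSV, as shown in the hunk above
    # A pl.Trainer(...).fit(model, datamodule=datamodule) call would then drive the
    # usual LightningDataModule hooks; the rest of the module is not visible in this diff.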
@@ -67,48 +69,54 @@ class CSVDataset(Dataset):
                  size=512,
                  repeats=1,
                  interpolation="bicubic",
-                 flip_p=0.5,
                  placeholder_token="*",
                  center_crop=False,
+                 batch_size=1,
                  ):
 
         self.data = data
         self.tokenizer = tokenizer
-
-        self.num_images = len(self.data)
-        self._length = self.num_images * repeats
-
         self.placeholder_token = placeholder_token
+        self.batch_size = batch_size
+        self.cache = {}
 
-        self.size = size
-        self.center_crop = center_crop
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
-                              }[interpolation]
-        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+        self.num_instance_images = len(self.data)
+        self._length = self.num_instance_images * repeats
 
-        self.cache = {}
+        self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
+                              "bilinear": transforms.InterpolationMode.BILINEAR,
+                              "bicubic": transforms.InterpolationMode.BICUBIC,
+                              "lanczos": transforms.InterpolationMode.LANCZOS,
+                              }[interpolation]
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=self.interpolation),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
-        return self._length
+        return math.ceil(self._length / self.batch_size) * self.batch_size
 
-    def get_example(self, i, flipped):
-        image_path, text = self.data[i % self.num_images]
+    def get_example(self, i):
+        image_path, text = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
 
         example = {}
-        image = Image.open(image_path)
 
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+        instance_image = Image.open(image_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
 
         text = text.format(self.placeholder_token)
 
-        example["prompt"] = text
+        example["prompts"] = text
+        example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
             text,
             padding="max_length",
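Aside (not from the commit): the torchvision pipeline introduced above reproduces the hand-rolled preprocessing that the next hunk deletes. ToTensor() followed by Normalize([0.5], [0.5]) is mathematically the same as the old (pixel / 127.5 - 1.0) scaling; a quick self-contained check on a random dummy image, independent of the dataset:

    # Verifies that Normalize([0.5], [0.5]) after ToTensor() matches the removed
    # score-sde style scaling (pixel / 127.5 - 1.0) on a dummy RGB image.
    import numpy as np
    from PIL import Image
    from torchvision import transforms

    img = Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))

    new_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    new_tensor = new_pipeline(img)                               # shape (3, 512, 512), values in [-1, 1]

    old_array = np.array(img).astype(np.float32) / 127.5 - 1.0   # old manual scaling
    print(np.allclose(new_tensor.permute(1, 2, 0).numpy(), old_array, atol=1e-6))  # True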
@@ -117,29 +125,15 @@ class CSVDataset(Dataset):
             return_tensors="pt",
         ).input_ids[0]
 
-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
-
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = img.shape[0], img.shape[1]
-            img = img[(h - crop) // 2:(h + crop) // 2,
-                      (w - crop) // 2:(w + crop) // 2]
-
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size),
-                             resample=self.interpolation)
-        image = self.flip(image)
-        image = np.array(image).astype(np.uint8)
-        image = (image / 127.5 - 1.0).astype(np.float32)
-
-        example["key"] = "-".join([image_path, "-", str(flipped)])
-        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
-
         self.cache[image_path] = example
         return example
 
     def __getitem__(self, i):
-        flipped = random.choice([False, True])
-        example = self.get_example(i, flipped)
+        example = {}
+        unprocessed_example = self.get_example(i)
+
+        example["prompts"] = unprocessed_example["prompts"]
+        example["input_ids"] = unprocessed_example["input_ids"]
+        example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
+
         return example
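One behavioral note, sketched below rather than stated in the commit: __len__ now rounds the dataset length up to a multiple of batch_size while get_example keeps wrapping the index with a modulo, so the final batch is padded by reusing the first items instead of being dropped or left short. The numbers in the sketch are made up for illustration:

    # Illustration only: 10 rows, repeats=1, batch_size=4 are hypothetical values.
    import math

    num_instance_images = 10
    repeats = 1
    batch_size = 4

    length = num_instance_images * repeats                     # 10 raw examples
    padded_len = math.ceil(length / batch_size) * batch_size   # 12, what __len__ reports
    indices = [i % num_instance_images for i in range(padded_len)]
    print(indices[10:])   # [0, 1] -> items 0 and 1 are reused to fill the final batch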
