diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 162 |
1 files changed, 85 insertions, 77 deletions
diff --git a/data/csv.py b/data/csv.py index dcaf7d3..8637ac1 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -1,27 +1,38 @@ | |||
| 1 | import math | ||
| 1 | import pandas as pd | 2 | import pandas as pd |
| 2 | from pathlib import Path | 3 | from pathlib import Path |
| 3 | import pytorch_lightning as pl | 4 | import pytorch_lightning as pl |
| 4 | from PIL import Image | 5 | from PIL import Image |
| 5 | from torch.utils.data import Dataset, DataLoader, random_split | 6 | from torch.utils.data import Dataset, DataLoader, random_split |
| 6 | from torchvision import transforms | 7 | from torchvision import transforms |
| 8 | from typing import NamedTuple, List | ||
| 9 | |||
| 10 | |||
| 11 | class CSVDataItem(NamedTuple): | ||
| 12 | instance_image_path: Path | ||
| 13 | class_image_path: Path | ||
| 14 | prompt: str | ||
| 15 | nprompt: str | ||
| 7 | 16 | ||
| 8 | 17 | ||
| 9 | class CSVDataModule(pl.LightningDataModule): | 18 | class CSVDataModule(pl.LightningDataModule): |
| 10 | def __init__(self, | 19 | def __init__( |
| 11 | batch_size, | 20 | self, |
| 12 | data_file, | 21 | batch_size, |
| 13 | tokenizer, | 22 | data_file, |
| 14 | instance_identifier, | 23 | tokenizer, |
| 15 | class_identifier=None, | 24 | instance_identifier, |
| 16 | class_subdir="db_cls", | 25 | class_identifier=None, |
| 17 | num_class_images=2, | 26 | class_subdir="db_cls", |
| 18 | size=512, | 27 | num_class_images=100, |
| 19 | repeats=100, | 28 | size=512, |
| 20 | interpolation="bicubic", | 29 | repeats=100, |
| 21 | center_crop=False, | 30 | interpolation="bicubic", |
| 22 | valid_set_size=None, | 31 | center_crop=False, |
| 23 | generator=None, | 32 | valid_set_size=None, |
| 24 | collate_fn=None): | 33 | generator=None, |
| 34 | collate_fn=None | ||
| 35 | ): | ||
| 25 | super().__init__() | 36 | super().__init__() |
| 26 | 37 | ||
| 27 | self.data_file = Path(data_file) | 38 | self.data_file = Path(data_file) |
| @@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 46 | self.collate_fn = collate_fn | 57 | self.collate_fn = collate_fn |
| 47 | self.batch_size = batch_size | 58 | self.batch_size = batch_size |
| 48 | 59 | ||
| 60 | def prepare_subdata(self, data, num_class_images=1): | ||
| 61 | image_multiplier = max(math.ceil(num_class_images / len(data)), 1) | ||
| 62 | |||
| 63 | return [ | ||
| 64 | CSVDataItem( | ||
| 65 | self.data_root.joinpath(item.image), | ||
| 66 | self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), | ||
| 67 | item.prompt, | ||
| 68 | item.nprompt if "nprompt" in item else "" | ||
| 69 | ) | ||
| 70 | for item in data | ||
| 71 | if "skip" not in item or item.skip != "x" | ||
| 72 | for i in range(image_multiplier) | ||
| 73 | ] | ||
| 74 | |||
| 49 | def prepare_data(self): | 75 | def prepare_data(self): |
| 50 | metadata = pd.read_csv(self.data_file) | 76 | metadata = pd.read_csv(self.data_file) |
| 51 | instance_image_paths = [ | 77 | metadata = list(metadata.itertuples()) |
| 52 | self.data_root.joinpath(f) | 78 | num_images = len(metadata) |
| 53 | for f in metadata['image'].values | ||
| 54 | for i in range(self.num_class_images) | ||
| 55 | ] | ||
| 56 | class_image_paths = [ | ||
| 57 | self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") | ||
| 58 | for f in metadata['image'].values | ||
| 59 | for i in range(self.num_class_images) | ||
| 60 | ] | ||
| 61 | prompts = [ | ||
| 62 | prompt | ||
| 63 | for prompt in metadata['prompt'].values | ||
| 64 | for i in range(self.num_class_images) | ||
| 65 | ] | ||
| 66 | nprompts = [ | ||
| 67 | nprompt | ||
| 68 | for nprompt in metadata['nprompt'].values | ||
| 69 | for i in range(self.num_class_images) | ||
| 70 | ] if 'nprompt' in metadata else [""] * len(instance_image_paths) | ||
| 71 | skips = [ | ||
| 72 | skip | ||
| 73 | for skip in metadata['skip'].values | ||
| 74 | for i in range(self.num_class_images) | ||
| 75 | ] if 'skip' in metadata else [""] * len(instance_image_paths) | ||
| 76 | self.data = [ | ||
| 77 | (i, c, p, n) | ||
| 78 | for i, c, p, n, s | ||
| 79 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | ||
| 80 | if s != "x" | ||
| 81 | ] | ||
| 82 | 79 | ||
| 83 | def setup(self, stage=None): | 80 | valid_set_size = int(num_images * 0.2) |
| 84 | valid_set_size = int(len(self.data) * 0.2) | ||
| 85 | if self.valid_set_size: | 81 | if self.valid_set_size: |
| 86 | valid_set_size = min(valid_set_size, self.valid_set_size) | 82 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 87 | valid_set_size = max(valid_set_size, 1) | 83 | valid_set_size = max(valid_set_size, 1) |
| 88 | train_set_size = len(self.data) - valid_set_size | 84 | train_set_size = num_images - valid_set_size |
| 89 | 85 | ||
| 90 | self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) | 86 | data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator) |
| 91 | 87 | ||
| 92 | train_dataset = CSVDataset(self.data_train, self.tokenizer, | 88 | self.data_train = self.prepare_subdata(data_train, self.num_class_images) |
| 89 | self.data_val = self.prepare_subdata(data_val) | ||
| 90 | |||
| 91 | def setup(self, stage=None): | ||
| 92 | train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, | ||
| 93 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 93 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
| 94 | num_class_images=self.num_class_images, | 94 | num_class_images=self.num_class_images, |
| 95 | size=self.size, interpolation=self.interpolation, | 95 | size=self.size, interpolation=self.interpolation, |
| 96 | center_crop=self.center_crop, repeats=self.repeats) | 96 | center_crop=self.center_crop, repeats=self.repeats) |
| 97 | val_dataset = CSVDataset(self.data_val, self.tokenizer, | 97 | val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, |
| 98 | instance_identifier=self.instance_identifier, | 98 | instance_identifier=self.instance_identifier, |
| 99 | size=self.size, interpolation=self.interpolation, | 99 | size=self.size, interpolation=self.interpolation, |
| 100 | center_crop=self.center_crop, repeats=self.repeats) | 100 | center_crop=self.center_crop, repeats=self.repeats) |
| 101 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, | 101 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 102 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 102 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
| 103 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, | 103 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
| 104 | pin_memory=True, collate_fn=self.collate_fn) | 104 | pin_memory=True, collate_fn=self.collate_fn) |
| 105 | 105 | ||
| 106 | def train_dataloader(self): | 106 | def train_dataloader(self): |
| @@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 111 | 111 | ||
| 112 | 112 | ||
| 113 | class CSVDataset(Dataset): | 113 | class CSVDataset(Dataset): |
| 114 | def __init__(self, | 114 | def __init__( |
| 115 | data, | 115 | self, |
| 116 | tokenizer, | 116 | data: List[CSVDataItem], |
| 117 | instance_identifier, | 117 | tokenizer, |
| 118 | class_identifier=None, | 118 | instance_identifier, |
| 119 | num_class_images=2, | 119 | batch_size=1, |
| 120 | size=512, | 120 | class_identifier=None, |
| 121 | repeats=1, | 121 | num_class_images=0, |
| 122 | interpolation="bicubic", | 122 | size=512, |
| 123 | center_crop=False, | 123 | repeats=1, |
| 124 | ): | 124 | interpolation="bicubic", |
| 125 | center_crop=False, | ||
| 126 | ): | ||
| 125 | 127 | ||
| 126 | self.data = data | 128 | self.data = data |
| 127 | self.tokenizer = tokenizer | 129 | self.tokenizer = tokenizer |
| 130 | self.batch_size = batch_size | ||
| 128 | self.instance_identifier = instance_identifier | 131 | self.instance_identifier = instance_identifier |
| 129 | self.class_identifier = class_identifier | 132 | self.class_identifier = class_identifier |
| 130 | self.num_class_images = num_class_images | 133 | self.num_class_images = num_class_images |
| 131 | self.cache = {} | 134 | self.cache = {} |
| 135 | self.image_cache = {} | ||
| 132 | 136 | ||
| 133 | self.num_instance_images = len(self.data) | 137 | self.num_instance_images = len(self.data) |
| 134 | self._length = self.num_instance_images * repeats | 138 | self._length = self.num_instance_images * repeats |
| @@ -149,46 +153,50 @@ class CSVDataset(Dataset): | |||
| 149 | ) | 153 | ) |
| 150 | 154 | ||
| 151 | def __len__(self): | 155 | def __len__(self): |
| 152 | return self._length | 156 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 153 | 157 | ||
| 154 | def get_example(self, i): | 158 | def get_example(self, i): |
| 155 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] | 159 | item = self.data[i % self.num_instance_images] |
| 156 | cache_key = f"{instance_image_path}_{class_image_path}" | 160 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" |
| 157 | 161 | ||
| 158 | if cache_key in self.cache: | 162 | if cache_key in self.cache: |
| 159 | return self.cache[cache_key] | 163 | return self.cache[cache_key] |
| 160 | 164 | ||
| 161 | example = {} | 165 | example = {} |
| 162 | 166 | ||
| 163 | example["prompts"] = prompt | 167 | example["prompts"] = item.prompt |
| 164 | example["nprompts"] = nprompt | 168 | example["nprompts"] = item.nprompt |
| 165 | 169 | ||
| 166 | instance_image = Image.open(instance_image_path) | 170 | if item.instance_image_path in self.image_cache: |
| 167 | if not instance_image.mode == "RGB": | 171 | instance_image = self.image_cache[item.instance_image_path] |
| 168 | instance_image = instance_image.convert("RGB") | 172 | else: |
| 173 | instance_image = Image.open(item.instance_image_path) | ||
| 174 | if not instance_image.mode == "RGB": | ||
| 175 | instance_image = instance_image.convert("RGB") | ||
| 176 | self.image_cache[item.instance_image_path] = instance_image | ||
| 169 | 177 | ||
| 170 | example["instance_images"] = instance_image | 178 | example["instance_images"] = instance_image |
| 171 | example["instance_prompt_ids"] = self.tokenizer( | 179 | example["instance_prompt_ids"] = self.tokenizer( |
| 172 | prompt.format(self.instance_identifier), | 180 | item.prompt.format(self.instance_identifier), |
| 173 | padding="do_not_pad", | 181 | padding="do_not_pad", |
| 174 | truncation=True, | 182 | truncation=True, |
| 175 | max_length=self.tokenizer.model_max_length, | 183 | max_length=self.tokenizer.model_max_length, |
| 176 | ).input_ids | 184 | ).input_ids |
| 177 | 185 | ||
| 178 | if self.num_class_images != 0: | 186 | if self.num_class_images != 0: |
| 179 | class_image = Image.open(class_image_path) | 187 | class_image = Image.open(item.class_image_path) |
| 180 | if not class_image.mode == "RGB": | 188 | if not class_image.mode == "RGB": |
| 181 | class_image = class_image.convert("RGB") | 189 | class_image = class_image.convert("RGB") |
| 182 | 190 | ||
| 183 | example["class_images"] = class_image | 191 | example["class_images"] = class_image |
| 184 | example["class_prompt_ids"] = self.tokenizer( | 192 | example["class_prompt_ids"] = self.tokenizer( |
| 185 | prompt.format(self.class_identifier), | 193 | item.prompt.format(self.class_identifier), |
| 186 | padding="do_not_pad", | 194 | padding="do_not_pad", |
| 187 | truncation=True, | 195 | truncation=True, |
| 188 | max_length=self.tokenizer.model_max_length, | 196 | max_length=self.tokenizer.model_max_length, |
| 189 | ).input_ids | 197 | ).input_ids |
| 190 | 198 | ||
| 191 | self.cache[instance_image_path] = example | 199 | self.cache[item.instance_image_path] = example |
| 192 | return example | 200 | return example |
| 193 | 201 | ||
| 194 | def __getitem__(self, i): | 202 | def __getitem__(self, i): |
