Diffstat (limited to 'data')
-rw-r--r--  data/csv.py | 64
1 file changed, 39 insertions(+), 25 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
 import math
-import pandas as pd
 import torch
+import json
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List, Optional
+from typing import Dict, NamedTuple, List, Optional, Union
 
 from models.clip.prompt import PromptProcessor
 
 
+def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+    return {"content": prompt} if isinstance(prompt, str) else prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
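
The new prepare_prompt helper normalizes a prompt entry so that str.format(**...) works whether the metadata stores a plain string or a mapping of named fields. A minimal, self-contained sketch of the intended behavior (the example prompts and the extra "style" field are illustrative only, not taken from the commit):

    from typing import Dict, Union

    def prepare_prompt(prompt: Union[str, Dict[str, str]]):
        # A bare string is wrapped under the "content" key; a dict passes through unchanged.
        return {"content": prompt} if isinstance(prompt, str) else prompt

    # With the default template "{content}", both forms produce the same result.
    assert "{content}".format(**prepare_prompt("a cat")) == "a cat"
    # A dict entry lets a custom template reference additional named fields.
    assert "{content}, {style}".format(**prepare_prompt({"content": "a cat", "style": "oil painting"})) == "a cat, oil painting"
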
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
-    def prepare_subdata(self, data, num_class_images=1):
+    def prepare_subdata(self, template, data, num_class_images=1):
+        image = template["image"] if "image" in template else "{}"
+        prompt = template["prompt"] if "prompt" in template else "{content}"
+        nprompt = template["nprompt"] if "nprompt" in template else "{content}"
+
         image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
 
         return [
             CSVDataItem(
-                self.data_root.joinpath(item.image),
-                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
-                item.prompt,
-                item.nprompt
+                self.data_root.joinpath(image.format(item["image"])),
+                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
             for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
-        metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
-        num_images = len(metadata)
+        with open(self.data_file, 'rt') as f:
+            metadata = json.load(f)
+        template = metadata["template"] if "template" in metadata else {}
+        items = metadata["items"] if "items" in metadata else []
+
+        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        num_images = len(items)
 
         valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
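
prepare_data now loads the metadata file with json.load instead of pandas.read_json and splits it into an optional "template" of format strings plus a list of "items". A hedged sketch of a metadata document this loader would accept (the filenames and prompt text are made up; only the keys come from the code above):

    # Illustrative metadata structure; keys mirror what prepare_data/prepare_subdata read.
    metadata = {
        "template": {
            "image": "images/{}",              # formatted with item["image"]
            "prompt": "a photo of {content}",  # formatted with prepare_prompt(item["prompt"])
            "nprompt": "lowres, {content}"     # formatted with prepare_prompt(item["nprompt"])
        },
        "items": [
            {"image": "cat.jpg", "prompt": "a cat", "nprompt": "blurry"},
            {"image": "old.jpg", "prompt": "outdated", "skip": True}   # dropped by the skip filter
        ]
    }
    # The first item resolves to data_root/"images/cat.jpg", prompt "a photo of a cat"
    # and nprompt "lowres, blurry"; missing "prompt"/"nprompt" keys fall back to "".
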
@@ -85,10 +97,10 @@ class CSVDataModule(pl.LightningDataModule):
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
-        self.data_val = self.prepare_subdata(data_val)
+        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(template, data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
-        self.cache = {}
         self.image_cache = {}
+        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
 
         return image
 
+    def get_input_ids(self, prompt, identifier):
+        prompt = prompt.format(identifier)
+
+        if prompt in self.input_id_cache:
+            return self.input_id_cache[prompt]
+
+        input_ids = self.prompt_processor.get_input_ids(prompt)
+        self.input_id_cache[prompt] = input_ids
+
+        return input_ids
+
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
-        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
 
         example = {}
 
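
get_input_ids memoizes tokenized prompts keyed on the fully formatted prompt string, replacing the old whole-example cache. A stand-alone sketch of the same pattern, with a stub tokenizer in place of PromptProcessor (the class name, identifier "sks", and fake token ids are illustrative):

    class _PromptIdCache:
        # Mirrors CSVDataset.get_input_ids with a stub tokenizer.
        def __init__(self):
            self.input_id_cache = {}

        def get_input_ids(self, prompt, identifier):
            prompt = prompt.format(identifier)                 # e.g. "a photo of {}" -> "a photo of sks"
            if prompt in self.input_id_cache:
                return self.input_id_cache[prompt]             # cache hit: reuse earlier tokenization
            input_ids = [len(tok) for tok in prompt.split()]   # stub for prompt_processor.get_input_ids
            self.input_id_cache[prompt] = input_ids
            return input_ids

    cache = _PromptIdCache()
    first = cache.get_input_ids("a photo of {}", "sks")
    # A repeated call with the same prompt/identifier returns the cached object.
    assert cache.get_input_ids("a photo of {}", "sks") is first
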
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
         example["nprompts"] = item.nprompt
 
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            item.prompt.format(self.instance_identifier)
-        )
+        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
-                item.nprompt.format(self.class_identifier)
-            )
+            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
-        self.cache[cache_key] = example
         return example
 
     def __getitem__(self, i):
