summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 64
1 files changed, 39 insertions, 25 deletions
diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
1import math 1import math
2import pandas as pd
3import torch 2import torch
3import json
4from pathlib import Path 4from pathlib import Path
5import pytorch_lightning as pl 5import pytorch_lightning as pl
6from PIL import Image 6from PIL import Image
7from torch.utils.data import Dataset, DataLoader, random_split 7from torch.utils.data import Dataset, DataLoader, random_split
8from torchvision import transforms 8from torchvision import transforms
9from typing import NamedTuple, List, Optional 9from typing import Dict, NamedTuple, List, Optional, Union
10 10
11from models.clip.prompt import PromptProcessor 11from models.clip.prompt import PromptProcessor
12 12
13 13
14def prepare_prompt(prompt: Union[str, Dict[str, str]]):
15 return {"content": prompt} if isinstance(prompt, str) else prompt
16
17
14class CSVDataItem(NamedTuple): 18class CSVDataItem(NamedTuple):
15 instance_image_path: Path 19 instance_image_path: Path
16 class_image_path: Path 20 class_image_path: Path
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
60 self.collate_fn = collate_fn 64 self.collate_fn = collate_fn
61 self.batch_size = batch_size 65 self.batch_size = batch_size
62 66
63 def prepare_subdata(self, data, num_class_images=1): 67 def prepare_subdata(self, template, data, num_class_images=1):
68 image = template["image"] if "image" in template else "{}"
69 prompt = template["prompt"] if "prompt" in template else "{content}"
70 nprompt = template["nprompt"] if "nprompt" in template else "{content}"
71
64 image_multiplier = max(math.ceil(num_class_images / len(data)), 1) 72 image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
65 73
66 return [ 74 return [
67 CSVDataItem( 75 CSVDataItem(
68 self.data_root.joinpath(item.image), 76 self.data_root.joinpath(image.format(item["image"])),
69 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), 77 self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
70 item.prompt, 78 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
71 item.nprompt 79 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
72 ) 80 )
73 for item in data 81 for item in data
74 for i in range(image_multiplier) 82 for i in range(image_multiplier)
75 ] 83 ]
76 84
77 def prepare_data(self): 85 def prepare_data(self):
78 metadata = pd.read_json(self.data_file) 86 with open(self.data_file, 'rt') as f:
79 metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True] 87 metadata = json.load(f)
80 num_images = len(metadata) 88 template = metadata["template"] if "template" in metadata else {}
89 items = metadata["items"] if "items" in metadata else []
90
91 items = [item for item in items if not "skip" in item or item["skip"] != True]
92 num_images = len(items)
81 93
82 valid_set_size = int(num_images * 0.2) 94 valid_set_size = int(num_images * 0.2)
83 if self.valid_set_size: 95 if self.valid_set_size:
@@ -85,10 +97,10 @@ class CSVDataModule(pl.LightningDataModule):
85 valid_set_size = max(valid_set_size, 1) 97 valid_set_size = max(valid_set_size, 1)
86 train_set_size = num_images - valid_set_size 98 train_set_size = num_images - valid_set_size
87 99
88 data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator) 100 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
89 101
90 self.data_train = self.prepare_subdata(data_train, self.num_class_images) 102 self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
91 self.data_val = self.prepare_subdata(data_val) 103 self.data_val = self.prepare_subdata(template, data_val)
92 104
93 def setup(self, stage=None): 105 def setup(self, stage=None):
94 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, 106 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
133 self.instance_identifier = instance_identifier 145 self.instance_identifier = instance_identifier
134 self.class_identifier = class_identifier 146 self.class_identifier = class_identifier
135 self.num_class_images = num_class_images 147 self.num_class_images = num_class_images
136 self.cache = {}
137 self.image_cache = {} 148 self.image_cache = {}
149 self.input_id_cache = {}
138 150
139 self.num_instance_images = len(self.data) 151 self.num_instance_images = len(self.data)
140 self._length = self.num_instance_images * repeats 152 self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
168 180
169 return image 181 return image
170 182
183 def get_input_ids(self, prompt, identifier):
184 prompt = prompt.format(identifier)
185
186 if prompt in self.input_id_cache:
187 return self.input_id_cache[prompt]
188
189 input_ids = self.prompt_processor.get_input_ids(prompt)
190 self.input_id_cache[prompt] = input_ids
191
192 return input_ids
193
171 def get_example(self, i): 194 def get_example(self, i):
172 item = self.data[i % self.num_instance_images] 195 item = self.data[i % self.num_instance_images]
173 cache_key = f"{item.instance_image_path}_{item.class_image_path}"
174
175 if cache_key in self.cache:
176 return self.cache[cache_key]
177 196
178 example = {} 197 example = {}
179 198
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
181 example["nprompts"] = item.nprompt 200 example["nprompts"] = item.nprompt
182 201
183 example["instance_images"] = self.get_image(item.instance_image_path) 202 example["instance_images"] = self.get_image(item.instance_image_path)
184 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 203 example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
185 item.prompt.format(self.instance_identifier)
186 )
187 204
188 if self.num_class_images != 0: 205 if self.num_class_images != 0:
189 example["class_images"] = self.get_image(item.class_image_path) 206 example["class_images"] = self.get_image(item.class_image_path)
190 example["class_prompt_ids"] = self.prompt_processor.get_input_ids( 207 example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
191 item.nprompt.format(self.class_identifier)
192 )
193 208
194 self.cache[cache_key] = example
195 return example 209 return example
196 210
197 def __getitem__(self, i): 211 def __getitem__(self, i):