summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py144
1 files changed, 88 insertions, 56 deletions
diff --git a/data/csv.py b/data/csv.py
index 4986153..59d6d8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor
11from data.keywords import prompt_to_keywords, keywords_to_prompt 11from data.keywords import prompt_to_keywords, keywords_to_prompt
12 12
13 13
14image_cache: dict[str, Image.Image] = {}
15
16
17def get_image(path):
18 if path in image_cache:
19 return image_cache[path]
20
21 image = Image.open(path)
22 if not image.mode == "RGB":
23 image = image.convert("RGB")
24 image_cache[path] = image
25
26 return image
27
28
14def prepare_prompt(prompt: Union[str, Dict[str, str]]): 29def prepare_prompt(prompt: Union[str, Dict[str, str]]):
15 return {"content": prompt} if isinstance(prompt, str) else prompt 30 return {"content": prompt} if isinstance(prompt, str) else prompt
16 31
17 32
18class CSVDataItem(NamedTuple): 33class VlpnDataItem(NamedTuple):
19 instance_image_path: Path 34 instance_image_path: Path
20 class_image_path: Path 35 class_image_path: Path
21 prompt: list[str] 36 prompt: list[str]
@@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple):
24 collection: list[str] 39 collection: list[str]
25 40
26 41
27class CSVDataModule(): 42class VlpnDataBucket():
43 def __init__(self, width: int, height: int):
44 self.width = width
45 self.height = height
46 self.ratio = width / height
47 self.items: list[VlpnDataItem] = []
48
49
50class VlpnDataModule():
28 def __init__( 51 def __init__(
29 self, 52 self,
30 batch_size: int, 53 batch_size: int,
@@ -36,11 +59,10 @@ class CSVDataModule():
36 repeats: int = 1, 59 repeats: int = 1,
37 dropout: float = 0, 60 dropout: float = 0,
38 interpolation: str = "bicubic", 61 interpolation: str = "bicubic",
39 center_crop: bool = False,
40 template_key: str = "template", 62 template_key: str = "template",
41 valid_set_size: Optional[int] = None, 63 valid_set_size: Optional[int] = None,
42 seed: Optional[int] = None, 64 seed: Optional[int] = None,
43 filter: Optional[Callable[[CSVDataItem], bool]] = None, 65 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
44 collate_fn=None, 66 collate_fn=None,
45 num_workers: int = 0 67 num_workers: int = 0
46 ): 68 ):
@@ -60,7 +82,6 @@ class CSVDataModule():
60 self.size = size 82 self.size = size
61 self.repeats = repeats 83 self.repeats = repeats
62 self.dropout = dropout 84 self.dropout = dropout
63 self.center_crop = center_crop
64 self.template_key = template_key 85 self.template_key = template_key
65 self.interpolation = interpolation 86 self.interpolation = interpolation
66 self.valid_set_size = valid_set_size 87 self.valid_set_size = valid_set_size
@@ -70,14 +91,14 @@ class CSVDataModule():
70 self.num_workers = num_workers 91 self.num_workers = num_workers
71 self.batch_size = batch_size 92 self.batch_size = batch_size
72 93
73 def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: 94 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
74 image = template["image"] if "image" in template else "{}" 95 image = template["image"] if "image" in template else "{}"
75 prompt = template["prompt"] if "prompt" in template else "{content}" 96 prompt = template["prompt"] if "prompt" in template else "{content}"
76 cprompt = template["cprompt"] if "cprompt" in template else "{content}" 97 cprompt = template["cprompt"] if "cprompt" in template else "{content}"
77 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 98 nprompt = template["nprompt"] if "nprompt" in template else "{content}"
78 99
79 return [ 100 return [
80 CSVDataItem( 101 VlpnDataItem(
81 self.data_root.joinpath(image.format(item["image"])), 102 self.data_root.joinpath(image.format(item["image"])),
82 None, 103 None,
83 prompt_to_keywords( 104 prompt_to_keywords(
@@ -97,17 +118,17 @@ class CSVDataModule():
97 for item in data 118 for item in data
98 ] 119 ]
99 120
100 def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: 121 def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
101 if self.filter is None: 122 if self.filter is None:
102 return items 123 return items
103 124
104 return [item for item in items if self.filter(item)] 125 return [item for item in items if self.filter(item)]
105 126
106 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: 127 def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
107 image_multiplier = max(num_class_images, 1) 128 image_multiplier = max(num_class_images, 1)
108 129
109 return [ 130 return [
110 CSVDataItem( 131 VlpnDataItem(
111 item.instance_image_path, 132 item.instance_image_path,
112 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), 133 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
113 item.prompt, 134 item.prompt,
@@ -119,7 +140,30 @@ class CSVDataModule():
119 for i in range(image_multiplier) 140 for i in range(image_multiplier)
120 ] 141 ]
121 142
122 def prepare_data(self): 143 def generate_buckets(self, items: list[VlpnDataItem]):
144 buckets = [VlpnDataBucket(self.size, self.size)]
145
146 for i in range(1, 5):
147 s = self.size + i * 64
148 buckets.append(VlpnDataBucket(s, self.size))
149 buckets.append(VlpnDataBucket(self.size, s))
150
151 for item in items:
152 image = get_image(item.instance_image_path)
153 ratio = image.width / image.height
154
155 if ratio >= 1:
156 candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
157 else:
158 candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
159
160 for bucket in candidates:
161 bucket.items.append(item)
162
163 buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
164 return buckets
165
166 def setup(self):
123 with open(self.data_file, 'rt') as f: 167 with open(self.data_file, 'rt') as f:
124 metadata = json.load(f) 168 metadata = json.load(f)
125 template = metadata[self.template_key] if self.template_key in metadata else {} 169 template = metadata[self.template_key] if self.template_key in metadata else {}
@@ -144,48 +188,48 @@ class CSVDataModule():
144 self.data_train = self.pad_items(data_train, self.num_class_images) 188 self.data_train = self.pad_items(data_train, self.num_class_images)
145 self.data_val = self.pad_items(data_val) 189 self.data_val = self.pad_items(data_val)
146 190
147 def setup(self, stage=None): 191 buckets = self.generate_buckets(data_train)
148 train_dataset = CSVDataset( 192
149 self.data_train, self.prompt_processor, batch_size=self.batch_size, 193 train_datasets = [
150 num_class_images=self.num_class_images, 194 VlpnDataset(
151 size=self.size, interpolation=self.interpolation, 195 bucket.items, self.prompt_processor, batch_size=self.batch_size,
152 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout 196 width=bucket.width, height=bucket.height, interpolation=self.interpolation,
153 ) 197 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
154 val_dataset = CSVDataset( 198 )
155 self.data_val, self.prompt_processor, batch_size=self.batch_size, 199 for bucket in buckets
156 size=self.size, interpolation=self.interpolation, 200 ]
157 center_crop=self.center_crop 201
158 ) 202 val_dataset = VlpnDataset(
159 self.train_dataloader_ = DataLoader( 203 data_val, self.prompt_processor, batch_size=self.batch_size,
160 train_dataset, batch_size=self.batch_size, 204 width=self.size, height=self.size, interpolation=self.interpolation,
161 shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
162 num_workers=self.num_workers
163 )
164 self.val_dataloader_ = DataLoader(
165 val_dataset, batch_size=self.batch_size,
166 pin_memory=True, collate_fn=self.collate_fn,
167 num_workers=self.num_workers
168 ) 205 )
169 206
170 def train_dataloader(self): 207 self.train_dataloaders = [
171 return self.train_dataloader_ 208 DataLoader(
209 dataset, batch_size=self.batch_size, shuffle=True,
210 pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
211 )
212 for dataset in train_datasets
213 ]
172 214
173 def val_dataloader(self): 215 self.val_dataloader = DataLoader(
174 return self.val_dataloader_ 216 val_dataset, batch_size=self.batch_size,
217 pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
218 )
175 219
176 220
177class CSVDataset(Dataset): 221class VlpnDataset(Dataset):
178 def __init__( 222 def __init__(
179 self, 223 self,
180 data: List[CSVDataItem], 224 data: List[VlpnDataItem],
181 prompt_processor: PromptProcessor, 225 prompt_processor: PromptProcessor,
182 batch_size: int = 1, 226 batch_size: int = 1,
183 num_class_images: int = 0, 227 num_class_images: int = 0,
184 size: int = 768, 228 width: int = 768,
229 height: int = 768,
185 repeats: int = 1, 230 repeats: int = 1,
186 dropout: float = 0, 231 dropout: float = 0,
187 interpolation: str = "bicubic", 232 interpolation: str = "bicubic",
188 center_crop: bool = False,
189 ): 233 ):
190 234
191 self.data = data 235 self.data = data
@@ -193,7 +237,6 @@ class CSVDataset(Dataset):
193 self.batch_size = batch_size 237 self.batch_size = batch_size
194 self.num_class_images = num_class_images 238 self.num_class_images = num_class_images
195 self.dropout = dropout 239 self.dropout = dropout
196 self.image_cache = {}
197 240
198 self.num_instance_images = len(self.data) 241 self.num_instance_images = len(self.data)
199 self._length = self.num_instance_images * repeats 242 self._length = self.num_instance_images * repeats
@@ -206,8 +249,8 @@ class CSVDataset(Dataset):
206 }[interpolation] 249 }[interpolation]
207 self.image_transforms = transforms.Compose( 250 self.image_transforms = transforms.Compose(
208 [ 251 [
209 transforms.Resize(size, interpolation=self.interpolation), 252 transforms.Resize(min(width, height), interpolation=self.interpolation),
210 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 253 transforms.RandomCrop((height, width)),
211 transforms.RandomHorizontalFlip(), 254 transforms.RandomHorizontalFlip(),
212 transforms.ToTensor(), 255 transforms.ToTensor(),
213 transforms.Normalize([0.5], [0.5]), 256 transforms.Normalize([0.5], [0.5]),
@@ -217,17 +260,6 @@ class CSVDataset(Dataset):
217 def __len__(self): 260 def __len__(self):
218 return math.ceil(self._length / self.batch_size) * self.batch_size 261 return math.ceil(self._length / self.batch_size) * self.batch_size
219 262
220 def get_image(self, path):
221 if path in self.image_cache:
222 return self.image_cache[path]
223
224 image = Image.open(path)
225 if not image.mode == "RGB":
226 image = image.convert("RGB")
227 self.image_cache[path] = image
228
229 return image
230
231 def get_example(self, i): 263 def get_example(self, i):
232 item = self.data[i % self.num_instance_images] 264 item = self.data[i % self.num_instance_images]
233 265
@@ -235,9 +267,9 @@ class CSVDataset(Dataset):
235 example["prompts"] = item.prompt 267 example["prompts"] = item.prompt
236 example["cprompts"] = item.cprompt 268 example["cprompts"] = item.cprompt
237 example["nprompts"] = item.nprompt 269 example["nprompts"] = item.nprompt
238 example["instance_images"] = self.get_image(item.instance_image_path) 270 example["instance_images"] = get_image(item.instance_image_path)
239 if self.num_class_images != 0: 271 if self.num_class_images != 0:
240 example["class_images"] = self.get_image(item.class_image_path) 272 example["class_images"] = get_image(item.class_image_path)
241 273
242 return example 274 return example
243 275