summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py273
1 files changed, 154 insertions, 119 deletions
diff --git a/data/csv.py b/data/csv.py
index 654aec1..9be36ba 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,20 +2,28 @@ import math
2import torch 2import torch
3import json 3import json
4from pathlib import Path 4from pathlib import Path
5from typing import NamedTuple, Optional, Union, Callable
6
5from PIL import Image 7from PIL import Image
6from torch.utils.data import Dataset, DataLoader, random_split
7from torchvision import transforms
8from typing import Dict, NamedTuple, List, Optional, Union, Callable
9 8
10import numpy as np 9from torch.utils.data import IterableDataset, DataLoader, random_split
10from torchvision import transforms
11 11
12from models.clip.prompt import PromptProcessor
13from data.keywords import prompt_to_keywords, keywords_to_prompt 12from data.keywords import prompt_to_keywords, keywords_to_prompt
13from models.clip.prompt import PromptProcessor
14 14
15 15
16image_cache: dict[str, Image.Image] = {} 16image_cache: dict[str, Image.Image] = {}
17 17
18 18
19interpolations = {
20 "linear": transforms.InterpolationMode.NEAREST,
21 "bilinear": transforms.InterpolationMode.BILINEAR,
22 "bicubic": transforms.InterpolationMode.BICUBIC,
23 "lanczos": transforms.InterpolationMode.LANCZOS,
24}
25
26
19def get_image(path): 27def get_image(path):
20 if path in image_cache: 28 if path in image_cache:
21 return image_cache[path] 29 return image_cache[path]
@@ -28,10 +36,46 @@ def get_image(path):
28 return image 36 return image
29 37
30 38
31def prepare_prompt(prompt: Union[str, Dict[str, str]]): 39def prepare_prompt(prompt: Union[str, dict[str, str]]):
32 return {"content": prompt} if isinstance(prompt, str) else prompt 40 return {"content": prompt} if isinstance(prompt, str) else prompt
33 41
34 42
43def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
44 item_order: list[int] = []
45 item_buckets: list[int] = []
46 buckets = [1.0]
47
48 for i in range(1, num_buckets + 1):
49 s = size + i * 64
50 buckets.append(s / size)
51 buckets.append(size / s)
52
53 buckets = torch.tensor(buckets)
54 bucket_indices = torch.arange(len(buckets))
55
56 for i, item in enumerate(items):
57 image = get_image(item)
58 ratio = image.width / image.height
59
60 if ratio >= 1:
61 mask = torch.bitwise_and(buckets >= 1, buckets <= ratio)
62 else:
63 mask = torch.bitwise_and(buckets <= 1, buckets >= ratio)
64
65 if not progressive_buckets:
66 mask = (buckets + (~mask) * math.inf - ratio).abs().argmin()
67
68 indices = bucket_indices[mask]
69
70 if len(indices.shape) == 0:
71 indices = indices.unsqueeze(0)
72
73 item_order += [i] * len(indices)
74 item_buckets += indices
75
76 return buckets.tolist(), item_order, item_buckets
77
78
35class VlpnDataItem(NamedTuple): 79class VlpnDataItem(NamedTuple):
36 instance_image_path: Path 80 instance_image_path: Path
37 class_image_path: Path 81 class_image_path: Path
@@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple):
41 collection: list[str] 85 collection: list[str]
42 86
43 87
44class VlpnDataBucket():
45 def __init__(self, width: int, height: int):
46 self.width = width
47 self.height = height
48 self.ratio = width / height
49 self.items: list[VlpnDataItem] = []
50
51
52class VlpnDataModule(): 88class VlpnDataModule():
53 def __init__( 89 def __init__(
54 self, 90 self,
@@ -60,7 +96,6 @@ class VlpnDataModule():
60 size: int = 768, 96 size: int = 768,
61 num_aspect_ratio_buckets: int = 0, 97 num_aspect_ratio_buckets: int = 0,
62 progressive_aspect_ratio_buckets: bool = False, 98 progressive_aspect_ratio_buckets: bool = False,
63 repeats: int = 1,
64 dropout: float = 0, 99 dropout: float = 0,
65 interpolation: str = "bicubic", 100 interpolation: str = "bicubic",
66 template_key: str = "template", 101 template_key: str = "template",
@@ -86,7 +121,6 @@ class VlpnDataModule():
86 self.size = size 121 self.size = size
87 self.num_aspect_ratio_buckets = num_aspect_ratio_buckets 122 self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
88 self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets 123 self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
89 self.repeats = repeats
90 self.dropout = dropout 124 self.dropout = dropout
91 self.template_key = template_key 125 self.template_key = template_key
92 self.interpolation = interpolation 126 self.interpolation = interpolation
@@ -146,36 +180,6 @@ class VlpnDataModule():
146 for i in range(image_multiplier) 180 for i in range(image_multiplier)
147 ] 181 ]
148 182
149 def generate_buckets(self, items: list[VlpnDataItem]):
150 buckets = [VlpnDataBucket(self.size, self.size)]
151
152 for i in range(1, self.num_aspect_ratio_buckets + 1):
153 s = self.size + i * 64
154 buckets.append(VlpnDataBucket(s, self.size))
155 buckets.append(VlpnDataBucket(self.size, s))
156
157 buckets = np.array(buckets)
158 bucket_ratios = np.array([bucket.ratio for bucket in buckets])
159
160 for item in items:
161 image = get_image(item.instance_image_path)
162 ratio = image.width / image.height
163
164 if ratio >= 1:
165 mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
166 else:
167 mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
168
169 if not self.progressive_aspect_ratio_buckets:
170 ratios = bucket_ratios.copy()
171 ratios[~mask] = math.inf
172 mask = [np.argmin(np.abs(ratios - ratio))]
173
174 for bucket in buckets[mask]:
175 bucket.items.append(item)
176
177 return [bucket for bucket in buckets if len(bucket.items) != 0]
178
179 def setup(self): 183 def setup(self):
180 with open(self.data_file, 'rt') as f: 184 with open(self.data_file, 'rt') as f:
181 metadata = json.load(f) 185 metadata = json.load(f)
@@ -201,105 +205,136 @@ class VlpnDataModule():
201 self.data_train = self.pad_items(data_train, self.num_class_images) 205 self.data_train = self.pad_items(data_train, self.num_class_images)
202 self.data_val = self.pad_items(data_val) 206 self.data_val = self.pad_items(data_val)
203 207
204 buckets = self.generate_buckets(data_train) 208 train_dataset = VlpnDataset(
205 209 self.data_train, self.prompt_processor,
206 train_datasets = [ 210 num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
207 VlpnDataset( 211 batch_size=self.batch_size,
208 bucket.items, self.prompt_processor, 212 size=self.size, interpolation=self.interpolation,
209 width=bucket.width, height=bucket.height, interpolation=self.interpolation, 213 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
210 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, 214 )
211 )
212 for bucket in buckets
213 ]
214 215
215 val_dataset = VlpnDataset( 216 val_dataset = VlpnDataset(
216 data_val, self.prompt_processor, 217 self.data_val, self.prompt_processor,
217 width=self.size, height=self.size, interpolation=self.interpolation, 218 batch_size=self.batch_size,
219 size=self.size, interpolation=self.interpolation,
218 ) 220 )
219 221
220 self.train_dataloaders = [ 222 self.train_dataloader = DataLoader(
221 DataLoader( 223 train_dataset,
222 dataset, batch_size=self.batch_size, shuffle=True, 224 batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
223 pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers 225 )
224 )
225 for dataset in train_datasets
226 ]
227 226
228 self.val_dataloader = DataLoader( 227 self.val_dataloader = DataLoader(
229 val_dataset, batch_size=self.batch_size, 228 val_dataset,
230 pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers 229 batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
231 ) 230 )
232 231
233 232
234class VlpnDataset(Dataset): 233class VlpnDataset(IterableDataset):
235 def __init__( 234 def __init__(
236 self, 235 self,
237 data: List[VlpnDataItem], 236 items: list[VlpnDataItem],
238 prompt_processor: PromptProcessor, 237 prompt_processor: PromptProcessor,
238 num_buckets: int = 1,
239 progressive_buckets: bool = False,
240 batch_size: int = 1,
239 num_class_images: int = 0, 241 num_class_images: int = 0,
240 width: int = 768, 242 size: int = 768,
241 height: int = 768,
242 repeats: int = 1,
243 dropout: float = 0, 243 dropout: float = 0,
244 shuffle: bool = False,
244 interpolation: str = "bicubic", 245 interpolation: str = "bicubic",
246 generator: Optional[torch.Generator] = None,
245 ): 247 ):
248 self.items = items
249 self.batch_size = batch_size
246 250
247 self.data = data
248 self.prompt_processor = prompt_processor 251 self.prompt_processor = prompt_processor
249 self.num_class_images = num_class_images 252 self.num_class_images = num_class_images
253 self.size = size
250 self.dropout = dropout 254 self.dropout = dropout
251 255 self.shuffle = shuffle
252 self.num_instance_images = len(self.data) 256 self.interpolation = interpolations[interpolation]
253 self._length = self.num_instance_images * repeats 257 self.generator = generator
254 258
255 self.interpolation = { 259 buckets, item_order, item_buckets = generate_buckets(
256 "linear": transforms.InterpolationMode.NEAREST, 260 [item.instance_image_path for item in items],
257 "bilinear": transforms.InterpolationMode.BILINEAR, 261 size,
258 "bicubic": transforms.InterpolationMode.BICUBIC, 262 num_buckets,
259 "lanczos": transforms.InterpolationMode.LANCZOS, 263 progressive_buckets
260 }[interpolation]
261 self.image_transforms = transforms.Compose(
262 [
263 transforms.Resize(min(width, height), interpolation=self.interpolation),
264 transforms.RandomCrop((height, width)),
265 transforms.RandomHorizontalFlip(),
266 transforms.ToTensor(),
267 transforms.Normalize([0.5], [0.5]),
268 ]
269 ) 264 )
270 265
271 def __len__(self): 266 self.buckets = torch.tensor(buckets)
272 return self._length 267 self.item_order = torch.tensor(item_order)
268 self.item_buckets = torch.tensor(item_buckets)
273 269
274 def get_example(self, i): 270 def __len__(self):
275 item = self.data[i % self.num_instance_images] 271 return len(self.item_buckets)
276 272
277 example = {} 273 def __iter__(self):
278 example["prompts"] = item.prompt 274 worker_info = torch.utils.data.get_worker_info()
279 example["cprompts"] = item.cprompt 275
280 example["nprompts"] = item.nprompt 276 if self.shuffle:
281 example["instance_images"] = get_image(item.instance_image_path) 277 perm = torch.randperm(len(self.item_buckets), generator=self.generator)
282 if self.num_class_images != 0: 278 self.item_order = self.item_order[perm]
283 example["class_images"] = get_image(item.class_image_path) 279 self.item_buckets = self.item_buckets[perm]
284 280
285 return example 281 item_mask = torch.ones_like(self.item_buckets, dtype=bool)
282 bucket = -1
283 image_transforms = None
284 batch = []
285 batch_size = self.batch_size
286
287 if worker_info is not None:
288 batch_size = math.ceil(batch_size / worker_info.num_workers)
289 worker_batch = math.ceil(len(self) / worker_info.num_workers)
290 start = worker_info.id * worker_batch
291 end = start + worker_batch
292 item_mask[:start] = False
293 item_mask[end:] = False
294
295 while item_mask.any():
296 item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
297
298 if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
299 yield batch
300 batch = []
301
302 if len(item_indices) == 0:
303 bucket = self.item_buckets[item_mask][0]
304 ratio = self.buckets[bucket]
305 width = self.size * ratio if ratio > 1 else self.size
306 height = self.size / ratio if ratio < 1 else self.size
307
308 image_transforms = transforms.Compose(
309 [
310 transforms.Resize(min(width, height), interpolation=self.interpolation),
311 transforms.RandomCrop((height, width)),
312 transforms.RandomHorizontalFlip(),
313 transforms.ToTensor(),
314 transforms.Normalize([0.5], [0.5]),
315 ]
316 )
317 else:
318 item_index = item_indices[0]
319 item = self.items[item_index]
320 item_mask[item_index] = False
286 321
287 def __getitem__(self, i): 322 example = {}
288 unprocessed_example = self.get_example(i)
289 323
290 example = {} 324 example["prompts"] = keywords_to_prompt(item.prompt)
325 example["cprompts"] = item.cprompt
326 example["nprompts"] = item.nprompt
291 327
292 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"]) 328 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
293 example["cprompts"] = unprocessed_example["cprompts"] 329 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
294 example["nprompts"] = unprocessed_example["nprompts"] 330 keywords_to_prompt(item.prompt, self.dropout, True)
331 )
295 332
296 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 333 if self.num_class_images != 0:
297 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 334 example["class_images"] = image_transforms(get_image(item.class_image_path))
298 keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) 335 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
299 )
300 336
301 if self.num_class_images != 0: 337 batch.append(example)
302 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
303 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
304 338
305 return example 339 if len(batch) != 0:
340 yield batch