diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 273 |
1 files changed, 154 insertions, 119 deletions
diff --git a/data/csv.py b/data/csv.py index 654aec1..9be36ba 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -2,20 +2,28 @@ import math | |||
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | from typing import NamedTuple, Optional, Union, Callable | ||
6 | |||
5 | from PIL import Image | 7 | from PIL import Image |
6 | from torch.utils.data import Dataset, DataLoader, random_split | ||
7 | from torchvision import transforms | ||
8 | from typing import Dict, NamedTuple, List, Optional, Union, Callable | ||
9 | 8 | ||
10 | import numpy as np | 9 | from torch.utils.data import IterableDataset, DataLoader, random_split |
10 | from torchvision import transforms | ||
11 | 11 | ||
12 | from models.clip.prompt import PromptProcessor | ||
13 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 12 | from data.keywords import prompt_to_keywords, keywords_to_prompt |
13 | from models.clip.prompt import PromptProcessor | ||
14 | 14 | ||
15 | 15 | ||
16 | image_cache: dict[str, Image.Image] = {} | 16 | image_cache: dict[str, Image.Image] = {} |
17 | 17 | ||
18 | 18 | ||
19 | interpolations = { | ||
20 | "linear": transforms.InterpolationMode.NEAREST, | ||
21 | "bilinear": transforms.InterpolationMode.BILINEAR, | ||
22 | "bicubic": transforms.InterpolationMode.BICUBIC, | ||
23 | "lanczos": transforms.InterpolationMode.LANCZOS, | ||
24 | } | ||
25 | |||
26 | |||
19 | def get_image(path): | 27 | def get_image(path): |
20 | if path in image_cache: | 28 | if path in image_cache: |
21 | return image_cache[path] | 29 | return image_cache[path] |
@@ -28,10 +36,46 @@ def get_image(path): | |||
28 | return image | 36 | return image |
29 | 37 | ||
30 | 38 | ||
31 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): | 39 | def prepare_prompt(prompt: Union[str, dict[str, str]]): |
32 | return {"content": prompt} if isinstance(prompt, str) else prompt | 40 | return {"content": prompt} if isinstance(prompt, str) else prompt |
33 | 41 | ||
34 | 42 | ||
43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): | ||
44 | item_order: list[int] = [] | ||
45 | item_buckets: list[int] = [] | ||
46 | buckets = [1.0] | ||
47 | |||
48 | for i in range(1, num_buckets + 1): | ||
49 | s = size + i * 64 | ||
50 | buckets.append(s / size) | ||
51 | buckets.append(size / s) | ||
52 | |||
53 | buckets = torch.tensor(buckets) | ||
54 | bucket_indices = torch.arange(len(buckets)) | ||
55 | |||
56 | for i, item in enumerate(items): | ||
57 | image = get_image(item) | ||
58 | ratio = image.width / image.height | ||
59 | |||
60 | if ratio >= 1: | ||
61 | mask = torch.bitwise_and(buckets >= 1, buckets <= ratio) | ||
62 | else: | ||
63 | mask = torch.bitwise_and(buckets <= 1, buckets >= ratio) | ||
64 | |||
65 | if not progressive_buckets: | ||
66 | mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() | ||
67 | |||
68 | indices = bucket_indices[mask] | ||
69 | |||
70 | if len(indices.shape) == 0: | ||
71 | indices = indices.unsqueeze(0) | ||
72 | |||
73 | item_order += [i] * len(indices) | ||
74 | item_buckets += indices | ||
75 | |||
76 | return buckets.tolist(), item_order, item_buckets | ||
77 | |||
78 | |||
35 | class VlpnDataItem(NamedTuple): | 79 | class VlpnDataItem(NamedTuple): |
36 | instance_image_path: Path | 80 | instance_image_path: Path |
37 | class_image_path: Path | 81 | class_image_path: Path |
@@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple): | |||
41 | collection: list[str] | 85 | collection: list[str] |
42 | 86 | ||
43 | 87 | ||
44 | class VlpnDataBucket(): | ||
45 | def __init__(self, width: int, height: int): | ||
46 | self.width = width | ||
47 | self.height = height | ||
48 | self.ratio = width / height | ||
49 | self.items: list[VlpnDataItem] = [] | ||
50 | |||
51 | |||
52 | class VlpnDataModule(): | 88 | class VlpnDataModule(): |
53 | def __init__( | 89 | def __init__( |
54 | self, | 90 | self, |
@@ -60,7 +96,6 @@ class VlpnDataModule(): | |||
60 | size: int = 768, | 96 | size: int = 768, |
61 | num_aspect_ratio_buckets: int = 0, | 97 | num_aspect_ratio_buckets: int = 0, |
62 | progressive_aspect_ratio_buckets: bool = False, | 98 | progressive_aspect_ratio_buckets: bool = False, |
63 | repeats: int = 1, | ||
64 | dropout: float = 0, | 99 | dropout: float = 0, |
65 | interpolation: str = "bicubic", | 100 | interpolation: str = "bicubic", |
66 | template_key: str = "template", | 101 | template_key: str = "template", |
@@ -86,7 +121,6 @@ class VlpnDataModule(): | |||
86 | self.size = size | 121 | self.size = size |
87 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets | 122 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets |
88 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets | 123 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets |
89 | self.repeats = repeats | ||
90 | self.dropout = dropout | 124 | self.dropout = dropout |
91 | self.template_key = template_key | 125 | self.template_key = template_key |
92 | self.interpolation = interpolation | 126 | self.interpolation = interpolation |
@@ -146,36 +180,6 @@ class VlpnDataModule(): | |||
146 | for i in range(image_multiplier) | 180 | for i in range(image_multiplier) |
147 | ] | 181 | ] |
148 | 182 | ||
149 | def generate_buckets(self, items: list[VlpnDataItem]): | ||
150 | buckets = [VlpnDataBucket(self.size, self.size)] | ||
151 | |||
152 | for i in range(1, self.num_aspect_ratio_buckets + 1): | ||
153 | s = self.size + i * 64 | ||
154 | buckets.append(VlpnDataBucket(s, self.size)) | ||
155 | buckets.append(VlpnDataBucket(self.size, s)) | ||
156 | |||
157 | buckets = np.array(buckets) | ||
158 | bucket_ratios = np.array([bucket.ratio for bucket in buckets]) | ||
159 | |||
160 | for item in items: | ||
161 | image = get_image(item.instance_image_path) | ||
162 | ratio = image.width / image.height | ||
163 | |||
164 | if ratio >= 1: | ||
165 | mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio) | ||
166 | else: | ||
167 | mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio) | ||
168 | |||
169 | if not self.progressive_aspect_ratio_buckets: | ||
170 | ratios = bucket_ratios.copy() | ||
171 | ratios[~mask] = math.inf | ||
172 | mask = [np.argmin(np.abs(ratios - ratio))] | ||
173 | |||
174 | for bucket in buckets[mask]: | ||
175 | bucket.items.append(item) | ||
176 | |||
177 | return [bucket for bucket in buckets if len(bucket.items) != 0] | ||
178 | |||
179 | def setup(self): | 183 | def setup(self): |
180 | with open(self.data_file, 'rt') as f: | 184 | with open(self.data_file, 'rt') as f: |
181 | metadata = json.load(f) | 185 | metadata = json.load(f) |
@@ -201,105 +205,136 @@ class VlpnDataModule(): | |||
201 | self.data_train = self.pad_items(data_train, self.num_class_images) | 205 | self.data_train = self.pad_items(data_train, self.num_class_images) |
202 | self.data_val = self.pad_items(data_val) | 206 | self.data_val = self.pad_items(data_val) |
203 | 207 | ||
204 | buckets = self.generate_buckets(data_train) | 208 | train_dataset = VlpnDataset( |
205 | 209 | self.data_train, self.prompt_processor, | |
206 | train_datasets = [ | 210 | num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, |
207 | VlpnDataset( | 211 | batch_size=self.batch_size, |
208 | bucket.items, self.prompt_processor, | 212 | size=self.size, interpolation=self.interpolation, |
209 | width=bucket.width, height=bucket.height, interpolation=self.interpolation, | 213 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
210 | num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, | 214 | ) |
211 | ) | ||
212 | for bucket in buckets | ||
213 | ] | ||
214 | 215 | ||
215 | val_dataset = VlpnDataset( | 216 | val_dataset = VlpnDataset( |
216 | data_val, self.prompt_processor, | 217 | self.data_val, self.prompt_processor, |
217 | width=self.size, height=self.size, interpolation=self.interpolation, | 218 | batch_size=self.batch_size, |
219 | size=self.size, interpolation=self.interpolation, | ||
218 | ) | 220 | ) |
219 | 221 | ||
220 | self.train_dataloaders = [ | 222 | self.train_dataloader = DataLoader( |
221 | DataLoader( | 223 | train_dataset, |
222 | dataset, batch_size=self.batch_size, shuffle=True, | 224 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers |
223 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 225 | ) |
224 | ) | ||
225 | for dataset in train_datasets | ||
226 | ] | ||
227 | 226 | ||
228 | self.val_dataloader = DataLoader( | 227 | self.val_dataloader = DataLoader( |
229 | val_dataset, batch_size=self.batch_size, | 228 | val_dataset, |
230 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 229 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers |
231 | ) | 230 | ) |
232 | 231 | ||
233 | 232 | ||
234 | class VlpnDataset(Dataset): | 233 | class VlpnDataset(IterableDataset): |
235 | def __init__( | 234 | def __init__( |
236 | self, | 235 | self, |
237 | data: List[VlpnDataItem], | 236 | items: list[VlpnDataItem], |
238 | prompt_processor: PromptProcessor, | 237 | prompt_processor: PromptProcessor, |
238 | num_buckets: int = 1, | ||
239 | progressive_buckets: bool = False, | ||
240 | batch_size: int = 1, | ||
239 | num_class_images: int = 0, | 241 | num_class_images: int = 0, |
240 | width: int = 768, | 242 | size: int = 768, |
241 | height: int = 768, | ||
242 | repeats: int = 1, | ||
243 | dropout: float = 0, | 243 | dropout: float = 0, |
244 | shuffle: bool = False, | ||
244 | interpolation: str = "bicubic", | 245 | interpolation: str = "bicubic", |
246 | generator: Optional[torch.Generator] = None, | ||
245 | ): | 247 | ): |
248 | self.items = items | ||
249 | self.batch_size = batch_size | ||
246 | 250 | ||
247 | self.data = data | ||
248 | self.prompt_processor = prompt_processor | 251 | self.prompt_processor = prompt_processor |
249 | self.num_class_images = num_class_images | 252 | self.num_class_images = num_class_images |
253 | self.size = size | ||
250 | self.dropout = dropout | 254 | self.dropout = dropout |
251 | 255 | self.shuffle = shuffle | |
252 | self.num_instance_images = len(self.data) | 256 | self.interpolation = interpolations[interpolation] |
253 | self._length = self.num_instance_images * repeats | 257 | self.generator = generator |
254 | 258 | ||
255 | self.interpolation = { | 259 | buckets, item_order, item_buckets = generate_buckets( |
256 | "linear": transforms.InterpolationMode.NEAREST, | 260 | [item.instance_image_path for item in items], |
257 | "bilinear": transforms.InterpolationMode.BILINEAR, | 261 | size, |
258 | "bicubic": transforms.InterpolationMode.BICUBIC, | 262 | num_buckets, |
259 | "lanczos": transforms.InterpolationMode.LANCZOS, | 263 | progressive_buckets |
260 | }[interpolation] | ||
261 | self.image_transforms = transforms.Compose( | ||
262 | [ | ||
263 | transforms.Resize(min(width, height), interpolation=self.interpolation), | ||
264 | transforms.RandomCrop((height, width)), | ||
265 | transforms.RandomHorizontalFlip(), | ||
266 | transforms.ToTensor(), | ||
267 | transforms.Normalize([0.5], [0.5]), | ||
268 | ] | ||
269 | ) | 264 | ) |
270 | 265 | ||
271 | def __len__(self): | 266 | self.buckets = torch.tensor(buckets) |
272 | return self._length | 267 | self.item_order = torch.tensor(item_order) |
268 | self.item_buckets = torch.tensor(item_buckets) | ||
273 | 269 | ||
274 | def get_example(self, i): | 270 | def __len__(self): |
275 | item = self.data[i % self.num_instance_images] | 271 | return len(self.item_buckets) |
276 | 272 | ||
277 | example = {} | 273 | def __iter__(self): |
278 | example["prompts"] = item.prompt | 274 | worker_info = torch.utils.data.get_worker_info() |
279 | example["cprompts"] = item.cprompt | 275 | |
280 | example["nprompts"] = item.nprompt | 276 | if self.shuffle: |
281 | example["instance_images"] = get_image(item.instance_image_path) | 277 | perm = torch.randperm(len(self.item_buckets), generator=self.generator) |
282 | if self.num_class_images != 0: | 278 | self.item_order = self.item_order[perm] |
283 | example["class_images"] = get_image(item.class_image_path) | 279 | self.item_buckets = self.item_buckets[perm] |
284 | 280 | ||
285 | return example | 281 | item_mask = torch.ones_like(self.item_buckets, dtype=bool) |
282 | bucket = -1 | ||
283 | image_transforms = None | ||
284 | batch = [] | ||
285 | batch_size = self.batch_size | ||
286 | |||
287 | if worker_info is not None: | ||
288 | batch_size = math.ceil(batch_size / worker_info.num_workers) | ||
289 | worker_batch = math.ceil(len(self) / worker_info.num_workers) | ||
290 | start = worker_info.id * worker_batch | ||
291 | end = start + worker_batch | ||
292 | item_mask[:start] = False | ||
293 | item_mask[end:] = False | ||
294 | |||
295 | while item_mask.any(): | ||
296 | item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] | ||
297 | |||
298 | if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): | ||
299 | yield batch | ||
300 | batch = [] | ||
301 | |||
302 | if len(item_indices) == 0: | ||
303 | bucket = self.item_buckets[item_mask][0] | ||
304 | ratio = self.buckets[bucket] | ||
305 | width = self.size * ratio if ratio > 1 else self.size | ||
306 | height = self.size / ratio if ratio < 1 else self.size | ||
307 | |||
308 | image_transforms = transforms.Compose( | ||
309 | [ | ||
310 | transforms.Resize(min(width, height), interpolation=self.interpolation), | ||
311 | transforms.RandomCrop((height, width)), | ||
312 | transforms.RandomHorizontalFlip(), | ||
313 | transforms.ToTensor(), | ||
314 | transforms.Normalize([0.5], [0.5]), | ||
315 | ] | ||
316 | ) | ||
317 | else: | ||
318 | item_index = item_indices[0] | ||
319 | item = self.items[item_index] | ||
320 | item_mask[item_index] = False | ||
286 | 321 | ||
287 | def __getitem__(self, i): | 322 | example = {} |
288 | unprocessed_example = self.get_example(i) | ||
289 | 323 | ||
290 | example = {} | 324 | example["prompts"] = keywords_to_prompt(item.prompt) |
325 | example["cprompts"] = item.cprompt | ||
326 | example["nprompts"] = item.nprompt | ||
291 | 327 | ||
292 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"]) | 328 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
293 | example["cprompts"] = unprocessed_example["cprompts"] | 329 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
294 | example["nprompts"] = unprocessed_example["nprompts"] | 330 | keywords_to_prompt(item.prompt, self.dropout, True) |
331 | ) | ||
295 | 332 | ||
296 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 333 | if self.num_class_images != 0: |
297 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 334 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
298 | keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) | 335 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) |
299 | ) | ||
300 | 336 | ||
301 | if self.num_class_images != 0: | 337 | batch.append(example) |
302 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | ||
303 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) | ||
304 | 338 | ||
305 | return example | 339 | if len(batch) != 0: |
340 | yield batch | ||