summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py78
1 files changed, 43 insertions, 35 deletions
diff --git a/data/csv.py b/data/csv.py
index 9be36ba..289a64d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
41 41
42 42
43def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): 43def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
44 item_order: list[int] = [] 44 bucket_items: list[int] = []
45 item_buckets: list[int] = [] 45 bucket_assignments: list[int] = []
46 buckets = [1.0] 46 buckets = [1.0]
47 47
48 for i in range(1, num_buckets + 1): 48 for i in range(1, num_buckets + 1):
@@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
70 if len(indices.shape) == 0: 70 if len(indices.shape) == 0:
71 indices = indices.unsqueeze(0) 71 indices = indices.unsqueeze(0)
72 72
73 item_order += [i] * len(indices) 73 bucket_items += [i] * len(indices)
74 item_buckets += indices 74 bucket_assignments += indices
75 75
76 return buckets.tolist(), item_order, item_buckets 76 return buckets.tolist(), bucket_items, bucket_assignments
77 77
78 78
79class VlpnDataItem(NamedTuple): 79class VlpnDataItem(NamedTuple):
@@ -94,8 +94,8 @@ class VlpnDataModule():
94 class_subdir: str = "cls", 94 class_subdir: str = "cls",
95 num_class_images: int = 1, 95 num_class_images: int = 1,
96 size: int = 768, 96 size: int = 768,
97 num_aspect_ratio_buckets: int = 0, 97 num_buckets: int = 0,
98 progressive_aspect_ratio_buckets: bool = False, 98 progressive_buckets: bool = False,
99 dropout: float = 0, 99 dropout: float = 0,
100 interpolation: str = "bicubic", 100 interpolation: str = "bicubic",
101 template_key: str = "template", 101 template_key: str = "template",
@@ -119,8 +119,8 @@ class VlpnDataModule():
119 119
120 self.prompt_processor = prompt_processor 120 self.prompt_processor = prompt_processor
121 self.size = size 121 self.size = size
122 self.num_aspect_ratio_buckets = num_aspect_ratio_buckets 122 self.num_buckets = num_buckets
123 self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets 123 self.progressive_buckets = progressive_buckets
124 self.dropout = dropout 124 self.dropout = dropout
125 self.template_key = template_key 125 self.template_key = template_key
126 self.interpolation = interpolation 126 self.interpolation = interpolation
@@ -207,15 +207,15 @@ class VlpnDataModule():
207 207
208 train_dataset = VlpnDataset( 208 train_dataset = VlpnDataset(
209 self.data_train, self.prompt_processor, 209 self.data_train, self.prompt_processor,
210 num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, 210 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
211 batch_size=self.batch_size, 211 batch_size=self.batch_size, generator=generator,
212 size=self.size, interpolation=self.interpolation, 212 size=self.size, interpolation=self.interpolation,
213 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, 213 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
214 ) 214 )
215 215
216 val_dataset = VlpnDataset( 216 val_dataset = VlpnDataset(
217 self.data_val, self.prompt_processor, 217 self.data_val, self.prompt_processor,
218 batch_size=self.batch_size, 218 batch_size=self.batch_size, generator=generator,
219 size=self.size, interpolation=self.interpolation, 219 size=self.size, interpolation=self.interpolation,
220 ) 220 )
221 221
@@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset):
256 self.interpolation = interpolations[interpolation] 256 self.interpolation = interpolations[interpolation]
257 self.generator = generator 257 self.generator = generator
258 258
259 buckets, item_order, item_buckets = generate_buckets( 259 buckets, bucket_items, bucket_assignments = generate_buckets(
260 [item.instance_image_path for item in items], 260 [item.instance_image_path for item in items],
261 size, 261 size,
262 num_buckets, 262 num_buckets,
@@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset):
264 ) 264 )
265 265
266 self.buckets = torch.tensor(buckets) 266 self.buckets = torch.tensor(buckets)
267 self.item_order = torch.tensor(item_order) 267 self.bucket_items = torch.tensor(bucket_items)
268 self.item_buckets = torch.tensor(item_buckets) 268 self.bucket_assignments = torch.tensor(bucket_assignments)
269 self.bucket_item_range = torch.arange(len(bucket_items))
270
271 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
269 272
270 def __len__(self): 273 def __len__(self):
271 return len(self.item_buckets) 274 return self.length_
272 275
273 def __iter__(self): 276 def __iter__(self):
274 worker_info = torch.utils.data.get_worker_info() 277 worker_info = torch.utils.data.get_worker_info()
275 278
276 if self.shuffle: 279 if self.shuffle:
277 perm = torch.randperm(len(self.item_buckets), generator=self.generator) 280 perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
278 self.item_order = self.item_order[perm] 281 self.bucket_items = self.bucket_items[perm]
279 self.item_buckets = self.item_buckets[perm] 282 self.bucket_assignments = self.bucket_assignments[perm]
280 283
281 item_mask = torch.ones_like(self.item_buckets, dtype=bool)
282 bucket = -1
283 image_transforms = None 284 image_transforms = None
285
286 mask = torch.ones_like(self.bucket_assignments, dtype=bool)
287 bucket = -1
284 batch = [] 288 batch = []
285 batch_size = self.batch_size 289 batch_size = self.batch_size
286 290
@@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset):
289 worker_batch = math.ceil(len(self) / worker_info.num_workers) 293 worker_batch = math.ceil(len(self) / worker_info.num_workers)
290 start = worker_info.id * worker_batch 294 start = worker_info.id * worker_batch
291 end = start + worker_batch 295 end = start + worker_batch
292 item_mask[:start] = False 296 mask[:start] = False
293 item_mask[end:] = False 297 mask[end:] = False
294 298
295 while item_mask.any(): 299 while mask.any():
296 item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] 300 bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
301 bucket_items = self.bucket_items[bucket_mask]
297 302
298 if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): 303 if len(batch) >= batch_size:
299 yield batch 304 yield batch
300 batch = [] 305 batch = []
301 306
302 if len(item_indices) == 0: 307 if len(bucket_items) == 0:
303 bucket = self.item_buckets[item_mask][0] 308 if len(batch) != 0:
309 yield batch
310 batch = []
311
312 bucket = self.bucket_assignments[mask][0]
304 ratio = self.buckets[bucket] 313 ratio = self.buckets[bucket]
305 width = self.size * ratio if ratio > 1 else self.size 314 width = self.size * ratio if ratio > 1 else self.size
306 height = self.size / ratio if ratio < 1 else self.size 315 height = self.size / ratio if ratio < 1 else self.size
307 316
308 image_transforms = transforms.Compose( 317 image_transforms = transforms.Compose(
309 [ 318 [
310 transforms.Resize(min(width, height), interpolation=self.interpolation), 319 transforms.Resize(self.size, interpolation=self.interpolation),
311 transforms.RandomCrop((height, width)), 320 transforms.RandomCrop((height, width)),
312 transforms.RandomHorizontalFlip(), 321 transforms.RandomHorizontalFlip(),
313 transforms.ToTensor(), 322 transforms.ToTensor(),
@@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset):
315 ] 324 ]
316 ) 325 )
317 else: 326 else:
318 item_index = item_indices[0] 327 item_index = bucket_items[0]
319 item = self.items[item_index] 328 item = self.items[item_index]
320 item_mask[item_index] = False 329 mask[self.bucket_item_range[bucket_mask][0]] = False
321 330
322 example = {} 331 example = {}
323 332
324 example["prompts"] = keywords_to_prompt(item.prompt) 333 example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
325 example["cprompts"] = item.cprompt 334 example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
326 example["nprompts"] = item.nprompt
327 335
328 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 336 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
329 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 337 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
@@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset):
332 340
333 if self.num_class_images != 0: 341 if self.num_class_images != 0:
334 example["class_images"] = image_transforms(get_image(item.class_image_path)) 342 example["class_images"] = image_transforms(get_image(item.class_image_path))
335 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) 343 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
336 344
337 batch.append(example) 345 batch.append(example)
338 346