Diffstat (limited to 'data')
-rw-r--r--  data/csv.py       183
-rw-r--r--  data/keywords.py    8
2 files changed, 131 insertions, 60 deletions
diff --git a/data/csv.py b/data/csv.py
index 14380e8..d726033 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ def generate_buckets(
     max_pixels: Optional[int] = None,
     num_buckets: int = 4,
     progressive_buckets: bool = False,
-    return_tensor: bool = True
+    return_tensor: bool = True,
 ):
     if max_pixels is None:
         max_pixels = (base_size + step_size) ** 2
@@ -62,7 +62,11 @@ def generate_buckets(
 
     for i in range(1, num_buckets + 1):
         long_side = base_size + i * step_size
-        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
+        short_side = min(
+            base_size
+            - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
+            base_size,
+        )
         buckets.append(long_side / short_side)
         buckets.append(short_side / long_side)
 
@@ -106,7 +110,7 @@ def collate_fn(
     max_token_id_length: Optional[int],
     with_guidance: bool,
     with_prior_preservation: bool,
-    examples
+    examples,
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
@@ -125,7 +129,9 @@ def collate_fn(
     prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
     nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
     inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(
+        tokenizer, negative_input_ids, max_token_id_length
+    )
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple):
     nprompt: str
     collection: list[str]
 
-    def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
-        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
+    def full_prompt(
+        self,
+        dropout: float = 0,
+        shuffle: bool = False,
+        npgenerator: Optional[np.random.Generator] = None,
+    ):
+        return keywords_to_str(
+            self.keywords, [self.prompt], dropout, shuffle, npgenerator
+        )
 
 
 def keyword_filter(
     placeholder_tokens: Optional[list[str]],
     collections: Optional[list[str]],
     exclude_collections: Optional[list[str]],
-    item: VlpnDataItem
+    item: VlpnDataItem,
 ):
     full_prompt = item.full_prompt()
 
     cond1 = placeholder_tokens is None or any(
-        token in full_prompt
-        for token in placeholder_tokens
+        token in full_prompt for token in placeholder_tokens
     )
     cond2 = collections is None or any(
-        collection in item.collection
-        for collection in collections
+        collection in item.collection for collection in collections
     )
     cond3 = exclude_collections is None or not any(
-        collection in item.collection
-        for collection in exclude_collections
+        collection in item.collection for collection in exclude_collections
     )
 
     return cond1 and cond2 and cond3
 
 
-class VlpnDataModule():
+class VlpnDataModule:
     def __init__(
         self,
         batch_size: int,
@@ -222,7 +232,7 @@ class VlpnDataModule():
 
         self.constant_prompt_length = constant_prompt_length
         self.max_token_id_length = None
-        
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -259,23 +269,29 @@ class VlpnDataModule():
             nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "")
             collection = item["collection"].split(", ") if "collection" in item else []
 
-            saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions)
+            saturated_keywords = str_to_keywords(
+                tpl_keywords.format(**keywords), expansions
+            )
 
-            inverted_tokens = keywords_to_str([
-                f"inv_{token}"
-                for token in self.placeholder_tokens
-                if token in saturated_keywords
-            ])
+            inverted_tokens = keywords_to_str(
+                [
+                    f"inv_{token}"
+                    for token in self.placeholder_tokens
+                    if token in saturated_keywords
+                ]
+            )
 
-            items.append(VlpnDataItem(
-                self.data_root / image,
-                None,
-                saturated_keywords,
-                tpl_prompt.format(**prompt),
-                tpl_cprompt.format(**prompt),
-                tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
-                collection
-            ))
+            items.append(
+                VlpnDataItem(
+                    self.data_root / image,
+                    None,
+                    saturated_keywords,
+                    tpl_prompt.format(**prompt),
+                    tpl_cprompt.format(**prompt),
+                    tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
+                    collection,
+                )
+            )
 
         return items
 
@@ -285,13 +301,16 @@ class VlpnDataModule():
 
         return [item for item in items if self.filter(item)]
 
-    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
+    def pad_items(
+        self, items: list[VlpnDataItem], num_class_images: int = 1
+    ) -> list[VlpnDataItem]:
         image_multiplier = max(num_class_images, 1)
 
         return [
             VlpnDataItem(
                 item.instance_image_path,
-                self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
+                self.class_root
+                / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
                 item.keywords,
                 item.prompt,
                 item.cprompt,
@@ -303,7 +322,7 @@ class VlpnDataModule():
         ]
 
     def setup(self):
-        with open(self.data_file, 'rt') as f:
+        with open(self.data_file, "rt") as f:
             metadata = json.load(f)
         template = metadata[self.template_key] if self.template_key in metadata else {}
         expansions = metadata["expansions"] if "expansions" in metadata else {}
@@ -312,25 +331,41 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
 
         if self.constant_prompt_length:
             all_input_ids = unify_input_ids(
                 self.tokenizer,
-                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+                [
+                    self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids
+                    for item in items
+                ],
             ).input_ids
             self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-        valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
+        valid_set_size = (
+            min(self.valid_set_size, num_images)
+            if self.valid_set_size is not None
+            else num_images // 10
+        )
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(
+            collate_fn,
+            self.dtype,
+            self.tokenizer,
+            self.max_token_id_length,
+            self.with_guidance,
+            self.num_class_images != 0,
+        )
 
         if valid_set_size == 0:
             data_train, data_val = items, items
         else:
-            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
+            data_train, data_val = random_split(
+                items, [train_set_size, valid_set_size], generator=self.generator
+            )
 
         data_train = self.pad_items(data_train, self.num_class_images)
 
@@ -338,17 +373,25 @@ class VlpnDataModule():
             data_train *= math.ceil(self.train_set_pad / len(data_train))
 
         self.train_dataset = VlpnDataset(
-            data_train, self.tokenizer,
-            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
-            bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, fill_batch=True, generator=self.generator,
-            size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
-            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
+            data_train,
+            self.tokenizer,
+            num_buckets=self.num_buckets,
+            progressive_buckets=self.progressive_buckets,
+            bucket_step_size=self.bucket_step_size,
+            bucket_max_pixels=self.bucket_max_pixels,
+            batch_size=self.batch_size,
+            fill_batch=True,
+            generator=self.generator,
+            size=self.size,
+            interpolation=self.interpolation,
+            color_jitter=self.color_jitter,
+            num_class_images=self.num_class_images,
+            dropout=self.dropout,
+            shuffle=self.shuffle,
         )
 
         self.train_dataloader = DataLoader(
-            self.train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_
+            self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
         if len(data_val) != 0:
@@ -358,16 +401,24 @@ class VlpnDataModule():
                 data_val *= math.ceil(self.valid_set_pad / len(data_val))
 
             self.val_dataset = VlpnDataset(
-                data_val, self.tokenizer,
-                num_buckets=self.num_buckets, progressive_buckets=True,
-                bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-                batch_size=self.batch_size, generator=self.generator,
-                size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
+                data_val,
+                self.tokenizer,
+                num_buckets=self.num_buckets,
+                progressive_buckets=True,
+                bucket_step_size=self.bucket_step_size,
+                bucket_max_pixels=self.bucket_max_pixels,
+                batch_size=self.batch_size,
+                generator=self.generator,
+                size=self.size,
+                interpolation=self.interpolation,
+                color_jitter=self.color_jitter,
             )
 
             self.val_dataloader = DataLoader(
                 self.val_dataset,
-                batch_size=None, pin_memory=True, collate_fn=collate_fn_
+                batch_size=None,
+                pin_memory=True,
+                collate_fn=collate_fn_,
             )
         else:
             self.val_dataloader = None
@@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset):
 
         self.bucket_item_range = torch.arange(len(self.bucket_items))
 
-        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
+        self.length_ = (
+            (self.bucket_assignments.bincount() / self.batch_size)
+            .ceil()
+            .long()
+            .sum()
+            .item()
+        )
 
     def get_input_ids(self, text: str):
         return self.tokenizer(text, padding="do_not_pad").input_ids
@@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset):
         worker_info = torch.utils.data.get_worker_info()
 
         if self.shuffle:
-            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            perm = torch.randperm(
+                len(self.bucket_assignments), generator=self.generator
+            )
             self.bucket_items = self.bucket_items[perm]
             self.bucket_assignments = self.bucket_assignments[perm]
 
@@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset):
 
             if len(bucket_items) == 0:
                 bucket_items = self.bucket_items[self.bucket_assignments == bucket]
-                item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)]
+                item_index = bucket_items[
+                    torch.randint(len(bucket_items), (1,), generator=self.generator)
+                ]
             else:
                 item_index = bucket_items[0]
                 mask[self.bucket_item_range[bucket_mask][0]] = False
@@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset):
             example["prompt_ids"] = self.get_input_ids(item.full_prompt())
             example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-            example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
+            example["instance_prompt_ids"] = self.get_input_ids(
+                item.full_prompt(self.dropout, True, self.npgenerator)
+            )
             example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
-            example["instance_images"] = image_transforms(get_image(item.instance_image_path))
+            example["instance_images"] = image_transforms(
+                get_image(item.instance_image_path)
+            )
 
             if self.num_class_images != 0:
                 example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
-                example["class_images"] = image_transforms(get_image(item.class_image_path))
+                example["class_images"] = image_transforms(
+                    get_image(item.class_image_path)
+                )
 
             batch.append(example)
diff --git a/data/keywords.py b/data/keywords.py
index 8632d67..83fe9ff 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -8,7 +8,7 @@ def keywords_to_str(
     undroppable_keywords: list[str] = [],
     dropout: float = 0,
     shuffle: bool = False,
-    npgenerator: Optional[np.random.Generator] = None
+    npgenerator: Optional[np.random.Generator] = None,
 ) -> str:
     if dropout != 0:
         keywords = [keyword for keyword in keywords if np.random.random() > dropout]
@@ -23,7 +23,11 @@ def keywords_to_str(
 
 def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]:
     def expand_keyword(keyword: str) -> list[str]:
-        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
+        return (
+            [keyword] + expansions[keyword].split(", ")
+            if keyword in expansions
+            else [keyword]
+        )
 
     return [
         kw
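
Note: the multi-line rewrite of expand_keyword above keeps the original behavior, because a Python conditional expression binds more loosely than "+"; the one-liner already grouped as "([keyword] + ...) if keyword in expansions else [keyword]". A minimal standalone sketch illustrating this (the expansions mapping here is a made-up example, not data from this repository):

# Hypothetical expansions mapping, for illustration only.
expansions = {"cat": "feline, pet"}


def expand_keyword(keyword: str) -> list[str]:
    # Same grouping as the original one-liner: the conditional expression
    # has lower precedence than "+", so the parentheses only add clarity.
    return (
        [keyword] + expansions[keyword].split(", ")
        if keyword in expansions
        else [keyword]
    )


print(expand_keyword("cat"))  # ['cat', 'feline', 'pet']
print(expand_keyword("dog"))  # ['dog']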