diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 183 | ||||
-rw-r--r-- | data/keywords.py | 8 |
2 files changed, 131 insertions, 60 deletions
diff --git a/data/csv.py b/data/csv.py index 14380e8..d726033 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -49,7 +49,7 @@ def generate_buckets( | |||
49 | max_pixels: Optional[int] = None, | 49 | max_pixels: Optional[int] = None, |
50 | num_buckets: int = 4, | 50 | num_buckets: int = 4, |
51 | progressive_buckets: bool = False, | 51 | progressive_buckets: bool = False, |
52 | return_tensor: bool = True | 52 | return_tensor: bool = True, |
53 | ): | 53 | ): |
54 | if max_pixels is None: | 54 | if max_pixels is None: |
55 | max_pixels = (base_size + step_size) ** 2 | 55 | max_pixels = (base_size + step_size) ** 2 |
@@ -62,7 +62,11 @@ def generate_buckets( | |||
62 | 62 | ||
63 | for i in range(1, num_buckets + 1): | 63 | for i in range(1, num_buckets + 1): |
64 | long_side = base_size + i * step_size | 64 | long_side = base_size + i * step_size |
65 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) | 65 | short_side = min( |
66 | base_size | ||
67 | - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, | ||
68 | base_size, | ||
69 | ) | ||
66 | buckets.append(long_side / short_side) | 70 | buckets.append(long_side / short_side) |
67 | buckets.append(short_side / long_side) | 71 | buckets.append(short_side / long_side) |
68 | 72 | ||
@@ -106,7 +110,7 @@ def collate_fn( | |||
106 | max_token_id_length: Optional[int], | 110 | max_token_id_length: Optional[int], |
107 | with_guidance: bool, | 111 | with_guidance: bool, |
108 | with_prior_preservation: bool, | 112 | with_prior_preservation: bool, |
109 | examples | 113 | examples, |
110 | ): | 114 | ): |
111 | prompt_ids = [example["prompt_ids"] for example in examples] | 115 | prompt_ids = [example["prompt_ids"] for example in examples] |
112 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 116 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
@@ -125,7 +129,9 @@ def collate_fn( | |||
125 | prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) | 129 | prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) |
126 | nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) | 130 | nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) |
127 | inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) | 131 | inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) |
128 | negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length) | 132 | negative_inputs = unify_input_ids( |
133 | tokenizer, negative_input_ids, max_token_id_length | ||
134 | ) | ||
129 | 135 | ||
130 | batch = { | 136 | batch = { |
131 | "prompt_ids": prompts.input_ids, | 137 | "prompt_ids": prompts.input_ids, |
@@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple): | |||
149 | nprompt: str | 155 | nprompt: str |
150 | collection: list[str] | 156 | collection: list[str] |
151 | 157 | ||
152 | def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): | 158 | def full_prompt( |
153 | return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) | 159 | self, |
160 | dropout: float = 0, | ||
161 | shuffle: bool = False, | ||
162 | npgenerator: Optional[np.random.Generator] = None, | ||
163 | ): | ||
164 | return keywords_to_str( | ||
165 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | ||
166 | ) | ||
154 | 167 | ||
155 | 168 | ||
156 | def keyword_filter( | 169 | def keyword_filter( |
157 | placeholder_tokens: Optional[list[str]], | 170 | placeholder_tokens: Optional[list[str]], |
158 | collections: Optional[list[str]], | 171 | collections: Optional[list[str]], |
159 | exclude_collections: Optional[list[str]], | 172 | exclude_collections: Optional[list[str]], |
160 | item: VlpnDataItem | 173 | item: VlpnDataItem, |
161 | ): | 174 | ): |
162 | full_prompt = item.full_prompt() | 175 | full_prompt = item.full_prompt() |
163 | 176 | ||
164 | cond1 = placeholder_tokens is None or any( | 177 | cond1 = placeholder_tokens is None or any( |
165 | token in full_prompt | 178 | token in full_prompt for token in placeholder_tokens |
166 | for token in placeholder_tokens | ||
167 | ) | 179 | ) |
168 | cond2 = collections is None or any( | 180 | cond2 = collections is None or any( |
169 | collection in item.collection | 181 | collection in item.collection for collection in collections |
170 | for collection in collections | ||
171 | ) | 182 | ) |
172 | cond3 = exclude_collections is None or not any( | 183 | cond3 = exclude_collections is None or not any( |
173 | collection in item.collection | 184 | collection in item.collection for collection in exclude_collections |
174 | for collection in exclude_collections | ||
175 | ) | 185 | ) |
176 | 186 | ||
177 | return cond1 and cond2 and cond3 | 187 | return cond1 and cond2 and cond3 |
178 | 188 | ||
179 | 189 | ||
180 | class VlpnDataModule(): | 190 | class VlpnDataModule: |
181 | def __init__( | 191 | def __init__( |
182 | self, | 192 | self, |
183 | batch_size: int, | 193 | batch_size: int, |
@@ -222,7 +232,7 @@ class VlpnDataModule(): | |||
222 | 232 | ||
223 | self.constant_prompt_length = constant_prompt_length | 233 | self.constant_prompt_length = constant_prompt_length |
224 | self.max_token_id_length = None | 234 | self.max_token_id_length = None |
225 | 235 | ||
226 | self.tokenizer = tokenizer | 236 | self.tokenizer = tokenizer |
227 | self.size = size | 237 | self.size = size |
228 | self.num_buckets = num_buckets | 238 | self.num_buckets = num_buckets |
@@ -259,23 +269,29 @@ class VlpnDataModule(): | |||
259 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") | 269 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") |
260 | collection = item["collection"].split(", ") if "collection" in item else [] | 270 | collection = item["collection"].split(", ") if "collection" in item else [] |
261 | 271 | ||
262 | saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) | 272 | saturated_keywords = str_to_keywords( |
273 | tpl_keywords.format(**keywords), expansions | ||
274 | ) | ||
263 | 275 | ||
264 | inverted_tokens = keywords_to_str([ | 276 | inverted_tokens = keywords_to_str( |
265 | f"inv_{token}" | 277 | [ |
266 | for token in self.placeholder_tokens | 278 | f"inv_{token}" |
267 | if token in saturated_keywords | 279 | for token in self.placeholder_tokens |
268 | ]) | 280 | if token in saturated_keywords |
281 | ] | ||
282 | ) | ||
269 | 283 | ||
270 | items.append(VlpnDataItem( | 284 | items.append( |
271 | self.data_root / image, | 285 | VlpnDataItem( |
272 | None, | 286 | self.data_root / image, |
273 | saturated_keywords, | 287 | None, |
274 | tpl_prompt.format(**prompt), | 288 | saturated_keywords, |
275 | tpl_cprompt.format(**prompt), | 289 | tpl_prompt.format(**prompt), |
276 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), | 290 | tpl_cprompt.format(**prompt), |
277 | collection | 291 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), |
278 | )) | 292 | collection, |
293 | ) | ||
294 | ) | ||
279 | 295 | ||
280 | return items | 296 | return items |
281 | 297 | ||
@@ -285,13 +301,16 @@ class VlpnDataModule(): | |||
285 | 301 | ||
286 | return [item for item in items if self.filter(item)] | 302 | return [item for item in items if self.filter(item)] |
287 | 303 | ||
288 | def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: | 304 | def pad_items( |
305 | self, items: list[VlpnDataItem], num_class_images: int = 1 | ||
306 | ) -> list[VlpnDataItem]: | ||
289 | image_multiplier = max(num_class_images, 1) | 307 | image_multiplier = max(num_class_images, 1) |
290 | 308 | ||
291 | return [ | 309 | return [ |
292 | VlpnDataItem( | 310 | VlpnDataItem( |
293 | item.instance_image_path, | 311 | item.instance_image_path, |
294 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | 312 | self.class_root |
313 | / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | ||
295 | item.keywords, | 314 | item.keywords, |
296 | item.prompt, | 315 | item.prompt, |
297 | item.cprompt, | 316 | item.cprompt, |
@@ -303,7 +322,7 @@ class VlpnDataModule(): | |||
303 | ] | 322 | ] |
304 | 323 | ||
305 | def setup(self): | 324 | def setup(self): |
306 | with open(self.data_file, 'rt') as f: | 325 | with open(self.data_file, "rt") as f: |
307 | metadata = json.load(f) | 326 | metadata = json.load(f) |
308 | template = metadata[self.template_key] if self.template_key in metadata else {} | 327 | template = metadata[self.template_key] if self.template_key in metadata else {} |
309 | expansions = metadata["expansions"] if "expansions" in metadata else {} | 328 | expansions = metadata["expansions"] if "expansions" in metadata else {} |
@@ -312,25 +331,41 @@ class VlpnDataModule(): | |||
312 | items = self.prepare_items(template, expansions, items) | 331 | items = self.prepare_items(template, expansions, items) |
313 | items = self.filter_items(items) | 332 | items = self.filter_items(items) |
314 | self.npgenerator.shuffle(items) | 333 | self.npgenerator.shuffle(items) |
315 | 334 | ||
316 | if self.constant_prompt_length: | 335 | if self.constant_prompt_length: |
317 | all_input_ids = unify_input_ids( | 336 | all_input_ids = unify_input_ids( |
318 | self.tokenizer, | 337 | self.tokenizer, |
319 | [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items] | 338 | [ |
339 | self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids | ||
340 | for item in items | ||
341 | ], | ||
320 | ).input_ids | 342 | ).input_ids |
321 | self.max_token_id_length = all_input_ids.shape[1] | 343 | self.max_token_id_length = all_input_ids.shape[1] |
322 | 344 | ||
323 | num_images = len(items) | 345 | num_images = len(items) |
324 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 346 | valid_set_size = ( |
347 | min(self.valid_set_size, num_images) | ||
348 | if self.valid_set_size is not None | ||
349 | else num_images // 10 | ||
350 | ) | ||
325 | train_set_size = max(num_images - valid_set_size, 1) | 351 | train_set_size = max(num_images - valid_set_size, 1) |
326 | valid_set_size = num_images - train_set_size | 352 | valid_set_size = num_images - train_set_size |
327 | 353 | ||
328 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0) | 354 | collate_fn_ = partial( |
355 | collate_fn, | ||
356 | self.dtype, | ||
357 | self.tokenizer, | ||
358 | self.max_token_id_length, | ||
359 | self.with_guidance, | ||
360 | self.num_class_images != 0, | ||
361 | ) | ||
329 | 362 | ||
330 | if valid_set_size == 0: | 363 | if valid_set_size == 0: |
331 | data_train, data_val = items, items | 364 | data_train, data_val = items, items |
332 | else: | 365 | else: |
333 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) | 366 | data_train, data_val = random_split( |
367 | items, [train_set_size, valid_set_size], generator=self.generator | ||
368 | ) | ||
334 | 369 | ||
335 | data_train = self.pad_items(data_train, self.num_class_images) | 370 | data_train = self.pad_items(data_train, self.num_class_images) |
336 | 371 | ||
@@ -338,17 +373,25 @@ class VlpnDataModule(): | |||
338 | data_train *= math.ceil(self.train_set_pad / len(data_train)) | 373 | data_train *= math.ceil(self.train_set_pad / len(data_train)) |
339 | 374 | ||
340 | self.train_dataset = VlpnDataset( | 375 | self.train_dataset = VlpnDataset( |
341 | data_train, self.tokenizer, | 376 | data_train, |
342 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 377 | self.tokenizer, |
343 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 378 | num_buckets=self.num_buckets, |
344 | batch_size=self.batch_size, fill_batch=True, generator=self.generator, | 379 | progressive_buckets=self.progressive_buckets, |
345 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 380 | bucket_step_size=self.bucket_step_size, |
346 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 381 | bucket_max_pixels=self.bucket_max_pixels, |
382 | batch_size=self.batch_size, | ||
383 | fill_batch=True, | ||
384 | generator=self.generator, | ||
385 | size=self.size, | ||
386 | interpolation=self.interpolation, | ||
387 | color_jitter=self.color_jitter, | ||
388 | num_class_images=self.num_class_images, | ||
389 | dropout=self.dropout, | ||
390 | shuffle=self.shuffle, | ||
347 | ) | 391 | ) |
348 | 392 | ||
349 | self.train_dataloader = DataLoader( | 393 | self.train_dataloader = DataLoader( |
350 | self.train_dataset, | 394 | self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
351 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
352 | ) | 395 | ) |
353 | 396 | ||
354 | if len(data_val) != 0: | 397 | if len(data_val) != 0: |
@@ -358,16 +401,24 @@ class VlpnDataModule(): | |||
358 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) | 401 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) |
359 | 402 | ||
360 | self.val_dataset = VlpnDataset( | 403 | self.val_dataset = VlpnDataset( |
361 | data_val, self.tokenizer, | 404 | data_val, |
362 | num_buckets=self.num_buckets, progressive_buckets=True, | 405 | self.tokenizer, |
363 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 406 | num_buckets=self.num_buckets, |
364 | batch_size=self.batch_size, generator=self.generator, | 407 | progressive_buckets=True, |
365 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 408 | bucket_step_size=self.bucket_step_size, |
409 | bucket_max_pixels=self.bucket_max_pixels, | ||
410 | batch_size=self.batch_size, | ||
411 | generator=self.generator, | ||
412 | size=self.size, | ||
413 | interpolation=self.interpolation, | ||
414 | color_jitter=self.color_jitter, | ||
366 | ) | 415 | ) |
367 | 416 | ||
368 | self.val_dataloader = DataLoader( | 417 | self.val_dataloader = DataLoader( |
369 | self.val_dataset, | 418 | self.val_dataset, |
370 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 419 | batch_size=None, |
420 | pin_memory=True, | ||
421 | collate_fn=collate_fn_, | ||
371 | ) | 422 | ) |
372 | else: | 423 | else: |
373 | self.val_dataloader = None | 424 | self.val_dataloader = None |
@@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset): | |||
418 | 469 | ||
419 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 470 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
420 | 471 | ||
421 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 472 | self.length_ = ( |
473 | (self.bucket_assignments.bincount() / self.batch_size) | ||
474 | .ceil() | ||
475 | .long() | ||
476 | .sum() | ||
477 | .item() | ||
478 | ) | ||
422 | 479 | ||
423 | def get_input_ids(self, text: str): | 480 | def get_input_ids(self, text: str): |
424 | return self.tokenizer(text, padding="do_not_pad").input_ids | 481 | return self.tokenizer(text, padding="do_not_pad").input_ids |
@@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset): | |||
430 | worker_info = torch.utils.data.get_worker_info() | 487 | worker_info = torch.utils.data.get_worker_info() |
431 | 488 | ||
432 | if self.shuffle: | 489 | if self.shuffle: |
433 | perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) | 490 | perm = torch.randperm( |
491 | len(self.bucket_assignments), generator=self.generator | ||
492 | ) | ||
434 | self.bucket_items = self.bucket_items[perm] | 493 | self.bucket_items = self.bucket_items[perm] |
435 | self.bucket_assignments = self.bucket_assignments[perm] | 494 | self.bucket_assignments = self.bucket_assignments[perm] |
436 | 495 | ||
@@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset): | |||
488 | 547 | ||
489 | if len(bucket_items) == 0: | 548 | if len(bucket_items) == 0: |
490 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] | 549 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] |
491 | item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] | 550 | item_index = bucket_items[ |
551 | torch.randint(len(bucket_items), (1,), generator=self.generator) | ||
552 | ] | ||
492 | else: | 553 | else: |
493 | item_index = bucket_items[0] | 554 | item_index = bucket_items[0] |
494 | mask[self.bucket_item_range[bucket_mask][0]] = False | 555 | mask[self.bucket_item_range[bucket_mask][0]] = False |
@@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset): | |||
500 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) | 561 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) |
501 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 562 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
502 | 563 | ||
503 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) | 564 | example["instance_prompt_ids"] = self.get_input_ids( |
565 | item.full_prompt(self.dropout, True, self.npgenerator) | ||
566 | ) | ||
504 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
505 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 568 | example["instance_images"] = image_transforms( |
569 | get_image(item.instance_image_path) | ||
570 | ) | ||
506 | 571 | ||
507 | if self.num_class_images != 0: | 572 | if self.num_class_images != 0: |
508 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | 573 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) |
509 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 574 | example["class_images"] = image_transforms( |
575 | get_image(item.class_image_path) | ||
576 | ) | ||
510 | 577 | ||
511 | batch.append(example) | 578 | batch.append(example) |
diff --git a/data/keywords.py b/data/keywords.py index 8632d67..83fe9ff 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
@@ -8,7 +8,7 @@ def keywords_to_str( | |||
8 | undroppable_keywords: list[str] = [], | 8 | undroppable_keywords: list[str] = [], |
9 | dropout: float = 0, | 9 | dropout: float = 0, |
10 | shuffle: bool = False, | 10 | shuffle: bool = False, |
11 | npgenerator: Optional[np.random.Generator] = None | 11 | npgenerator: Optional[np.random.Generator] = None, |
12 | ) -> str: | 12 | ) -> str: |
13 | if dropout != 0: | 13 | if dropout != 0: |
14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
@@ -23,7 +23,11 @@ def keywords_to_str( | |||
23 | 23 | ||
24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: | 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
25 | def expand_keyword(keyword: str) -> list[str]: | 25 | def expand_keyword(keyword: str) -> list[str]: |
26 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 26 | return ( |
27 | [keyword] + expansions[keyword].split(", ") | ||
28 | if keyword in expansions | ||
29 | else [keyword] | ||
30 | ) | ||
27 | 31 | ||
28 | return [ | 32 | return [ |
29 | kw | 33 | kw |