diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 183 | ||||
| -rw-r--r-- | data/keywords.py | 8 |
2 files changed, 131 insertions, 60 deletions
diff --git a/data/csv.py b/data/csv.py index 14380e8..d726033 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -49,7 +49,7 @@ def generate_buckets( | |||
| 49 | max_pixels: Optional[int] = None, | 49 | max_pixels: Optional[int] = None, |
| 50 | num_buckets: int = 4, | 50 | num_buckets: int = 4, |
| 51 | progressive_buckets: bool = False, | 51 | progressive_buckets: bool = False, |
| 52 | return_tensor: bool = True | 52 | return_tensor: bool = True, |
| 53 | ): | 53 | ): |
| 54 | if max_pixels is None: | 54 | if max_pixels is None: |
| 55 | max_pixels = (base_size + step_size) ** 2 | 55 | max_pixels = (base_size + step_size) ** 2 |
| @@ -62,7 +62,11 @@ def generate_buckets( | |||
| 62 | 62 | ||
| 63 | for i in range(1, num_buckets + 1): | 63 | for i in range(1, num_buckets + 1): |
| 64 | long_side = base_size + i * step_size | 64 | long_side = base_size + i * step_size |
| 65 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) | 65 | short_side = min( |
| 66 | base_size | ||
| 67 | - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, | ||
| 68 | base_size, | ||
| 69 | ) | ||
| 66 | buckets.append(long_side / short_side) | 70 | buckets.append(long_side / short_side) |
| 67 | buckets.append(short_side / long_side) | 71 | buckets.append(short_side / long_side) |
| 68 | 72 | ||
| @@ -106,7 +110,7 @@ def collate_fn( | |||
| 106 | max_token_id_length: Optional[int], | 110 | max_token_id_length: Optional[int], |
| 107 | with_guidance: bool, | 111 | with_guidance: bool, |
| 108 | with_prior_preservation: bool, | 112 | with_prior_preservation: bool, |
| 109 | examples | 113 | examples, |
| 110 | ): | 114 | ): |
| 111 | prompt_ids = [example["prompt_ids"] for example in examples] | 115 | prompt_ids = [example["prompt_ids"] for example in examples] |
| 112 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 116 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
| @@ -125,7 +129,9 @@ def collate_fn( | |||
| 125 | prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) | 129 | prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) |
| 126 | nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) | 130 | nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) |
| 127 | inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) | 131 | inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) |
| 128 | negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length) | 132 | negative_inputs = unify_input_ids( |
| 133 | tokenizer, negative_input_ids, max_token_id_length | ||
| 134 | ) | ||
| 129 | 135 | ||
| 130 | batch = { | 136 | batch = { |
| 131 | "prompt_ids": prompts.input_ids, | 137 | "prompt_ids": prompts.input_ids, |
| @@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple): | |||
| 149 | nprompt: str | 155 | nprompt: str |
| 150 | collection: list[str] | 156 | collection: list[str] |
| 151 | 157 | ||
| 152 | def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): | 158 | def full_prompt( |
| 153 | return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) | 159 | self, |
| 160 | dropout: float = 0, | ||
| 161 | shuffle: bool = False, | ||
| 162 | npgenerator: Optional[np.random.Generator] = None, | ||
| 163 | ): | ||
| 164 | return keywords_to_str( | ||
| 165 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | ||
| 166 | ) | ||
| 154 | 167 | ||
| 155 | 168 | ||
| 156 | def keyword_filter( | 169 | def keyword_filter( |
| 157 | placeholder_tokens: Optional[list[str]], | 170 | placeholder_tokens: Optional[list[str]], |
| 158 | collections: Optional[list[str]], | 171 | collections: Optional[list[str]], |
| 159 | exclude_collections: Optional[list[str]], | 172 | exclude_collections: Optional[list[str]], |
| 160 | item: VlpnDataItem | 173 | item: VlpnDataItem, |
| 161 | ): | 174 | ): |
| 162 | full_prompt = item.full_prompt() | 175 | full_prompt = item.full_prompt() |
| 163 | 176 | ||
| 164 | cond1 = placeholder_tokens is None or any( | 177 | cond1 = placeholder_tokens is None or any( |
| 165 | token in full_prompt | 178 | token in full_prompt for token in placeholder_tokens |
| 166 | for token in placeholder_tokens | ||
| 167 | ) | 179 | ) |
| 168 | cond2 = collections is None or any( | 180 | cond2 = collections is None or any( |
| 169 | collection in item.collection | 181 | collection in item.collection for collection in collections |
| 170 | for collection in collections | ||
| 171 | ) | 182 | ) |
| 172 | cond3 = exclude_collections is None or not any( | 183 | cond3 = exclude_collections is None or not any( |
| 173 | collection in item.collection | 184 | collection in item.collection for collection in exclude_collections |
| 174 | for collection in exclude_collections | ||
| 175 | ) | 185 | ) |
| 176 | 186 | ||
| 177 | return cond1 and cond2 and cond3 | 187 | return cond1 and cond2 and cond3 |
| 178 | 188 | ||
| 179 | 189 | ||
| 180 | class VlpnDataModule(): | 190 | class VlpnDataModule: |
| 181 | def __init__( | 191 | def __init__( |
| 182 | self, | 192 | self, |
| 183 | batch_size: int, | 193 | batch_size: int, |
| @@ -222,7 +232,7 @@ class VlpnDataModule(): | |||
| 222 | 232 | ||
| 223 | self.constant_prompt_length = constant_prompt_length | 233 | self.constant_prompt_length = constant_prompt_length |
| 224 | self.max_token_id_length = None | 234 | self.max_token_id_length = None |
| 225 | 235 | ||
| 226 | self.tokenizer = tokenizer | 236 | self.tokenizer = tokenizer |
| 227 | self.size = size | 237 | self.size = size |
| 228 | self.num_buckets = num_buckets | 238 | self.num_buckets = num_buckets |
| @@ -259,23 +269,29 @@ class VlpnDataModule(): | |||
| 259 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") | 269 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") |
| 260 | collection = item["collection"].split(", ") if "collection" in item else [] | 270 | collection = item["collection"].split(", ") if "collection" in item else [] |
| 261 | 271 | ||
| 262 | saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) | 272 | saturated_keywords = str_to_keywords( |
| 273 | tpl_keywords.format(**keywords), expansions | ||
| 274 | ) | ||
| 263 | 275 | ||
| 264 | inverted_tokens = keywords_to_str([ | 276 | inverted_tokens = keywords_to_str( |
| 265 | f"inv_{token}" | 277 | [ |
| 266 | for token in self.placeholder_tokens | 278 | f"inv_{token}" |
| 267 | if token in saturated_keywords | 279 | for token in self.placeholder_tokens |
| 268 | ]) | 280 | if token in saturated_keywords |
| 281 | ] | ||
| 282 | ) | ||
| 269 | 283 | ||
| 270 | items.append(VlpnDataItem( | 284 | items.append( |
| 271 | self.data_root / image, | 285 | VlpnDataItem( |
| 272 | None, | 286 | self.data_root / image, |
| 273 | saturated_keywords, | 287 | None, |
| 274 | tpl_prompt.format(**prompt), | 288 | saturated_keywords, |
| 275 | tpl_cprompt.format(**prompt), | 289 | tpl_prompt.format(**prompt), |
| 276 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), | 290 | tpl_cprompt.format(**prompt), |
| 277 | collection | 291 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), |
| 278 | )) | 292 | collection, |
| 293 | ) | ||
| 294 | ) | ||
| 279 | 295 | ||
| 280 | return items | 296 | return items |
| 281 | 297 | ||
| @@ -285,13 +301,16 @@ class VlpnDataModule(): | |||
| 285 | 301 | ||
| 286 | return [item for item in items if self.filter(item)] | 302 | return [item for item in items if self.filter(item)] |
| 287 | 303 | ||
| 288 | def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: | 304 | def pad_items( |
| 305 | self, items: list[VlpnDataItem], num_class_images: int = 1 | ||
| 306 | ) -> list[VlpnDataItem]: | ||
| 289 | image_multiplier = max(num_class_images, 1) | 307 | image_multiplier = max(num_class_images, 1) |
| 290 | 308 | ||
| 291 | return [ | 309 | return [ |
| 292 | VlpnDataItem( | 310 | VlpnDataItem( |
| 293 | item.instance_image_path, | 311 | item.instance_image_path, |
| 294 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | 312 | self.class_root |
| 313 | / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | ||
| 295 | item.keywords, | 314 | item.keywords, |
| 296 | item.prompt, | 315 | item.prompt, |
| 297 | item.cprompt, | 316 | item.cprompt, |
| @@ -303,7 +322,7 @@ class VlpnDataModule(): | |||
| 303 | ] | 322 | ] |
| 304 | 323 | ||
| 305 | def setup(self): | 324 | def setup(self): |
| 306 | with open(self.data_file, 'rt') as f: | 325 | with open(self.data_file, "rt") as f: |
| 307 | metadata = json.load(f) | 326 | metadata = json.load(f) |
| 308 | template = metadata[self.template_key] if self.template_key in metadata else {} | 327 | template = metadata[self.template_key] if self.template_key in metadata else {} |
| 309 | expansions = metadata["expansions"] if "expansions" in metadata else {} | 328 | expansions = metadata["expansions"] if "expansions" in metadata else {} |
| @@ -312,25 +331,41 @@ class VlpnDataModule(): | |||
| 312 | items = self.prepare_items(template, expansions, items) | 331 | items = self.prepare_items(template, expansions, items) |
| 313 | items = self.filter_items(items) | 332 | items = self.filter_items(items) |
| 314 | self.npgenerator.shuffle(items) | 333 | self.npgenerator.shuffle(items) |
| 315 | 334 | ||
| 316 | if self.constant_prompt_length: | 335 | if self.constant_prompt_length: |
| 317 | all_input_ids = unify_input_ids( | 336 | all_input_ids = unify_input_ids( |
| 318 | self.tokenizer, | 337 | self.tokenizer, |
| 319 | [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items] | 338 | [ |
| 339 | self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids | ||
| 340 | for item in items | ||
| 341 | ], | ||
| 320 | ).input_ids | 342 | ).input_ids |
| 321 | self.max_token_id_length = all_input_ids.shape[1] | 343 | self.max_token_id_length = all_input_ids.shape[1] |
| 322 | 344 | ||
| 323 | num_images = len(items) | 345 | num_images = len(items) |
| 324 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 346 | valid_set_size = ( |
| 347 | min(self.valid_set_size, num_images) | ||
| 348 | if self.valid_set_size is not None | ||
| 349 | else num_images // 10 | ||
| 350 | ) | ||
| 325 | train_set_size = max(num_images - valid_set_size, 1) | 351 | train_set_size = max(num_images - valid_set_size, 1) |
| 326 | valid_set_size = num_images - train_set_size | 352 | valid_set_size = num_images - train_set_size |
| 327 | 353 | ||
| 328 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0) | 354 | collate_fn_ = partial( |
| 355 | collate_fn, | ||
| 356 | self.dtype, | ||
| 357 | self.tokenizer, | ||
| 358 | self.max_token_id_length, | ||
| 359 | self.with_guidance, | ||
| 360 | self.num_class_images != 0, | ||
| 361 | ) | ||
| 329 | 362 | ||
| 330 | if valid_set_size == 0: | 363 | if valid_set_size == 0: |
| 331 | data_train, data_val = items, items | 364 | data_train, data_val = items, items |
| 332 | else: | 365 | else: |
| 333 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) | 366 | data_train, data_val = random_split( |
| 367 | items, [train_set_size, valid_set_size], generator=self.generator | ||
| 368 | ) | ||
| 334 | 369 | ||
| 335 | data_train = self.pad_items(data_train, self.num_class_images) | 370 | data_train = self.pad_items(data_train, self.num_class_images) |
| 336 | 371 | ||
| @@ -338,17 +373,25 @@ class VlpnDataModule(): | |||
| 338 | data_train *= math.ceil(self.train_set_pad / len(data_train)) | 373 | data_train *= math.ceil(self.train_set_pad / len(data_train)) |
| 339 | 374 | ||
| 340 | self.train_dataset = VlpnDataset( | 375 | self.train_dataset = VlpnDataset( |
| 341 | data_train, self.tokenizer, | 376 | data_train, |
| 342 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 377 | self.tokenizer, |
| 343 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 378 | num_buckets=self.num_buckets, |
| 344 | batch_size=self.batch_size, fill_batch=True, generator=self.generator, | 379 | progressive_buckets=self.progressive_buckets, |
| 345 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 380 | bucket_step_size=self.bucket_step_size, |
| 346 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 381 | bucket_max_pixels=self.bucket_max_pixels, |
| 382 | batch_size=self.batch_size, | ||
| 383 | fill_batch=True, | ||
| 384 | generator=self.generator, | ||
| 385 | size=self.size, | ||
| 386 | interpolation=self.interpolation, | ||
| 387 | color_jitter=self.color_jitter, | ||
| 388 | num_class_images=self.num_class_images, | ||
| 389 | dropout=self.dropout, | ||
| 390 | shuffle=self.shuffle, | ||
| 347 | ) | 391 | ) |
| 348 | 392 | ||
| 349 | self.train_dataloader = DataLoader( | 393 | self.train_dataloader = DataLoader( |
| 350 | self.train_dataset, | 394 | self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 351 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
| 352 | ) | 395 | ) |
| 353 | 396 | ||
| 354 | if len(data_val) != 0: | 397 | if len(data_val) != 0: |
| @@ -358,16 +401,24 @@ class VlpnDataModule(): | |||
| 358 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) | 401 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) |
| 359 | 402 | ||
| 360 | self.val_dataset = VlpnDataset( | 403 | self.val_dataset = VlpnDataset( |
| 361 | data_val, self.tokenizer, | 404 | data_val, |
| 362 | num_buckets=self.num_buckets, progressive_buckets=True, | 405 | self.tokenizer, |
| 363 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 406 | num_buckets=self.num_buckets, |
| 364 | batch_size=self.batch_size, generator=self.generator, | 407 | progressive_buckets=True, |
| 365 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 408 | bucket_step_size=self.bucket_step_size, |
| 409 | bucket_max_pixels=self.bucket_max_pixels, | ||
| 410 | batch_size=self.batch_size, | ||
| 411 | generator=self.generator, | ||
| 412 | size=self.size, | ||
| 413 | interpolation=self.interpolation, | ||
| 414 | color_jitter=self.color_jitter, | ||
| 366 | ) | 415 | ) |
| 367 | 416 | ||
| 368 | self.val_dataloader = DataLoader( | 417 | self.val_dataloader = DataLoader( |
| 369 | self.val_dataset, | 418 | self.val_dataset, |
| 370 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 419 | batch_size=None, |
| 420 | pin_memory=True, | ||
| 421 | collate_fn=collate_fn_, | ||
| 371 | ) | 422 | ) |
| 372 | else: | 423 | else: |
| 373 | self.val_dataloader = None | 424 | self.val_dataloader = None |
| @@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset): | |||
| 418 | 469 | ||
| 419 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 470 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
| 420 | 471 | ||
| 421 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 472 | self.length_ = ( |
| 473 | (self.bucket_assignments.bincount() / self.batch_size) | ||
| 474 | .ceil() | ||
| 475 | .long() | ||
| 476 | .sum() | ||
| 477 | .item() | ||
| 478 | ) | ||
| 422 | 479 | ||
| 423 | def get_input_ids(self, text: str): | 480 | def get_input_ids(self, text: str): |
| 424 | return self.tokenizer(text, padding="do_not_pad").input_ids | 481 | return self.tokenizer(text, padding="do_not_pad").input_ids |
| @@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset): | |||
| 430 | worker_info = torch.utils.data.get_worker_info() | 487 | worker_info = torch.utils.data.get_worker_info() |
| 431 | 488 | ||
| 432 | if self.shuffle: | 489 | if self.shuffle: |
| 433 | perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) | 490 | perm = torch.randperm( |
| 491 | len(self.bucket_assignments), generator=self.generator | ||
| 492 | ) | ||
| 434 | self.bucket_items = self.bucket_items[perm] | 493 | self.bucket_items = self.bucket_items[perm] |
| 435 | self.bucket_assignments = self.bucket_assignments[perm] | 494 | self.bucket_assignments = self.bucket_assignments[perm] |
| 436 | 495 | ||
| @@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset): | |||
| 488 | 547 | ||
| 489 | if len(bucket_items) == 0: | 548 | if len(bucket_items) == 0: |
| 490 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] | 549 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] |
| 491 | item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] | 550 | item_index = bucket_items[ |
| 551 | torch.randint(len(bucket_items), (1,), generator=self.generator) | ||
| 552 | ] | ||
| 492 | else: | 553 | else: |
| 493 | item_index = bucket_items[0] | 554 | item_index = bucket_items[0] |
| 494 | mask[self.bucket_item_range[bucket_mask][0]] = False | 555 | mask[self.bucket_item_range[bucket_mask][0]] = False |
| @@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset): | |||
| 500 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) | 561 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) |
| 501 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 562 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
| 502 | 563 | ||
| 503 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) | 564 | example["instance_prompt_ids"] = self.get_input_ids( |
| 565 | item.full_prompt(self.dropout, True, self.npgenerator) | ||
| 566 | ) | ||
| 504 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
| 505 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 568 | example["instance_images"] = image_transforms( |
| 569 | get_image(item.instance_image_path) | ||
| 570 | ) | ||
| 506 | 571 | ||
| 507 | if self.num_class_images != 0: | 572 | if self.num_class_images != 0: |
| 508 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | 573 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) |
| 509 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 574 | example["class_images"] = image_transforms( |
| 575 | get_image(item.class_image_path) | ||
| 576 | ) | ||
| 510 | 577 | ||
| 511 | batch.append(example) | 578 | batch.append(example) |
diff --git a/data/keywords.py b/data/keywords.py index 8632d67..83fe9ff 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
| @@ -8,7 +8,7 @@ def keywords_to_str( | |||
| 8 | undroppable_keywords: list[str] = [], | 8 | undroppable_keywords: list[str] = [], |
| 9 | dropout: float = 0, | 9 | dropout: float = 0, |
| 10 | shuffle: bool = False, | 10 | shuffle: bool = False, |
| 11 | npgenerator: Optional[np.random.Generator] = None | 11 | npgenerator: Optional[np.random.Generator] = None, |
| 12 | ) -> str: | 12 | ) -> str: |
| 13 | if dropout != 0: | 13 | if dropout != 0: |
| 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
| @@ -23,7 +23,11 @@ def keywords_to_str( | |||
| 23 | 23 | ||
| 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: | 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
| 25 | def expand_keyword(keyword: str) -> list[str]: | 25 | def expand_keyword(keyword: str) -> list[str]: |
| 26 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 26 | return ( |
| 27 | [keyword] + expansions[keyword].split(", ") | ||
| 28 | if keyword in expansions | ||
| 29 | else [keyword] | ||
| 30 | ) | ||
| 27 | 31 | ||
| 28 | return [ | 32 | return [ |
| 29 | kw | 33 | kw |
