diff options
| -rw-r--r-- | textual_dreambooth.py | 107 |
1 file changed, 38 insertions(+), 69 deletions(-)
diff --git a/textual_dreambooth.py b/textual_dreambooth.py
index a46953d..c07d98b 100644
--- a/textual_dreambooth.py
+++ b/textual_dreambooth.py
| @@ -68,20 +68,6 @@ def parse_args(): | |||
| 68 | help="A token to use as initializer word." | 68 | help="A token to use as initializer word." |
| 69 | ) | 69 | ) |
| 70 | parser.add_argument( | 70 | parser.add_argument( |
| 71 | "--num_vec_per_token", | ||
| 72 | type=int, | ||
| 73 | default=1, | ||
| 74 | help=( | ||
| 75 | "The number of vectors used to represent the placeholder token. The higher the number, the better the" | ||
| 76 | " result at the cost of editability. This can be fixed by prompt editing." | ||
| 77 | ), | ||
| 78 | ) | ||
| 79 | parser.add_argument( | ||
| 80 | "--initialize_rest_random", | ||
| 81 | action="store_true", | ||
| 82 | help="Initialize rest of the placeholder tokens with random." | ||
| 83 | ) | ||
| 84 | parser.add_argument( | ||
| 85 | "--use_class_images", | 71 | "--use_class_images", |
| 86 | action="store_true", | 72 | action="store_true", |
| 87 | default=True, | 73 | default=True, |
| @@ -324,40 +310,6 @@ def make_grid(images, rows, cols): | |||
| 324 | return grid | 310 | return grid |
| 325 | 311 | ||
| 326 | 312 | ||
| 327 | def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder): | ||
| 328 | assert args.num_vec_per_token >= len(token_ids) | ||
| 329 | placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)] | ||
| 330 | |||
| 331 | for placeholder_token in placeholder_tokens: | ||
| 332 | num_added_tokens = tokenizer.add_tokens(placeholder_token) | ||
| 333 | if num_added_tokens == 0: | ||
| 334 | raise ValueError( | ||
| 335 | f"The tokenizer already contains the token {placeholder_token}. Please pass a different" | ||
| 336 | " `placeholder_token` that is not already in the tokenizer." | ||
| 337 | ) | ||
| 338 | |||
| 339 | placeholder_token = " ".join(placeholder_tokens) | ||
| 340 | placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False) | ||
| 341 | |||
| 342 | print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}") | ||
| 343 | |||
| 344 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
| 345 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
| 346 | |||
| 347 | if args.initialize_rest_random: | ||
| 348 | # The idea is that the placeholder tokens form adjectives as in x x x white dog. | ||
| 349 | for i, placeholder_token_id in enumerate(placeholder_token_ids): | ||
| 350 | if len(placeholder_token_ids) - i < len(token_ids): | ||
| 351 | token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] | ||
| 352 | else: | ||
| 353 | token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id]) | ||
| 354 | else: | ||
| 355 | for i, placeholder_token_id in enumerate(placeholder_token_ids): | ||
| 356 | token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] | ||
| 357 | |||
| 358 | return placeholder_token, placeholder_token_ids | ||
| 359 | |||
| 360 | |||
| 361 | class Checkpointer: | 313 | class Checkpointer: |
| 362 | def __init__( | 314 | def __init__( |
| 363 | self, | 315 | self, |
| @@ -367,7 +319,7 @@ class Checkpointer: | |||
| 367 | unet, | 319 | unet, |
| 368 | tokenizer, | 320 | tokenizer, |
| 369 | placeholder_token, | 321 | placeholder_token, |
| 370 | placeholder_token_ids, | 322 | placeholder_token_id, |
| 371 | output_dir, | 323 | output_dir, |
| 372 | sample_image_size, | 324 | sample_image_size, |
| 373 | sample_batches, | 325 | sample_batches, |
| @@ -380,7 +332,7 @@ class Checkpointer: | |||
| 380 | self.unet = unet | 332 | self.unet = unet |
| 381 | self.tokenizer = tokenizer | 333 | self.tokenizer = tokenizer |
| 382 | self.placeholder_token = placeholder_token | 334 | self.placeholder_token = placeholder_token |
| 383 | self.placeholder_token_ids = placeholder_token_ids | 335 | self.placeholder_token_id = placeholder_token_id |
| 384 | self.output_dir = output_dir | 336 | self.output_dir = output_dir |
| 385 | self.sample_image_size = sample_image_size | 337 | self.sample_image_size = sample_image_size |
| 386 | self.seed = seed or torch.random.seed() | 338 | self.seed = seed or torch.random.seed() |
| @@ -398,10 +350,8 @@ class Checkpointer: | |||
| 398 | unwrapped = self.accelerator.unwrap_model(text_encoder) | 350 | unwrapped = self.accelerator.unwrap_model(text_encoder) |
| 399 | 351 | ||
| 400 | # Save a checkpoint | 352 | # Save a checkpoint |
| 401 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids] | 353 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] |
| 402 | learned_embeds_dict = {} | 354 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} |
| 403 | for i, placeholder_token in enumerate(self.placeholder_token.split(" ")): | ||
| 404 | learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu() | ||
| 405 | 355 | ||
| 406 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | 356 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) |
| 407 | if path is not None: | 357 | if path is not None: |
| @@ -527,6 +477,24 @@ def main(): | |||
| 527 | args.pretrained_model_name_or_path + '/tokenizer' | 477 | args.pretrained_model_name_or_path + '/tokenizer' |
| 528 | ) | 478 | ) |
| 529 | 479 | ||
| 480 | # Add the placeholder token in tokenizer | ||
| 481 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | ||
| 482 | if num_added_tokens == 0: | ||
| 483 | raise ValueError( | ||
| 484 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" | ||
| 485 | " `placeholder_token` that is not already in the tokenizer." | ||
| 486 | ) | ||
| 487 | |||
| 488 | # Convert the initializer_token, placeholder_token to ids | ||
| 489 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
| 490 | # Check if initializer_token is a single token or a sequence of tokens | ||
| 491 | if len(initializer_token_ids) > 1: | ||
| 492 | raise ValueError( | ||
| 493 | f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") | ||
| 494 | |||
| 495 | initializer_token_ids = torch.tensor(initializer_token_ids) | ||
| 496 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | ||
| 497 | |||
| 530 | # Load models and create wrapper for stable diffusion | 498 | # Load models and create wrapper for stable diffusion |
| 531 | text_encoder = CLIPTextModel.from_pretrained( | 499 | text_encoder = CLIPTextModel.from_pretrained( |
| 532 | args.pretrained_model_name_or_path + '/text_encoder', | 500 | args.pretrained_model_name_or_path + '/text_encoder', |
| @@ -544,16 +512,19 @@ def main(): | |||
| 544 | slice_size = unet.config.attention_head_dim // 2 | 512 | slice_size = unet.config.attention_head_dim // 2 |
| 545 | unet.set_attention_slice(slice_size) | 513 | unet.set_attention_slice(slice_size) |
| 546 | 514 | ||
| 547 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | 515 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
| 548 | # regardless of whether the number of token_ids is 1 or more, it'll set one and then keep repeating. | 516 | text_encoder.resize_token_embeddings(len(tokenizer)) |
| 549 | placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token( | ||
| 550 | args, token_ids, tokenizer, text_encoder) | ||
| 551 | 517 | ||
| 552 | # if args.resume_checkpoint is not None: | 518 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
| 553 | # token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | 519 | token_embeds = text_encoder.get_input_embeddings().weight.data |
| 554 | # args.placeholder_token] | 520 | |
| 555 | # else: | 521 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
| 556 | # token_embeds[placeholder_token_id] = initializer_token_embeddings | 522 | |
| 523 | if args.resume_checkpoint is not None: | ||
| 524 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | ||
| 525 | args.placeholder_token] | ||
| 526 | else: | ||
| 527 | token_embeds[placeholder_token_id] = initializer_token_embeddings | ||
| 557 | 528 | ||
| 558 | # Freeze vae and unet | 529 | # Freeze vae and unet |
| 559 | freeze_params(vae.parameters()) | 530 | freeze_params(vae.parameters()) |
| @@ -627,7 +598,7 @@ def main(): | |||
| 627 | data_file=args.train_data_file, | 598 | data_file=args.train_data_file, |
| 628 | batch_size=args.train_batch_size, | 599 | batch_size=args.train_batch_size, |
| 629 | tokenizer=tokenizer, | 600 | tokenizer=tokenizer, |
| 630 | instance_identifier=placeholder_token, | 601 | instance_identifier=args.placeholder_token, |
| 631 | class_identifier=args.initializer_token if args.use_class_images else None, | 602 | class_identifier=args.initializer_token if args.use_class_images else None, |
| 632 | class_subdir="ti_cls", | 603 | class_subdir="ti_cls", |
| 633 | size=args.resolution, | 604 | size=args.resolution, |
| @@ -690,7 +661,7 @@ def main(): | |||
| 690 | unet=unet, | 661 | unet=unet, |
| 691 | tokenizer=tokenizer, | 662 | tokenizer=tokenizer, |
| 692 | placeholder_token=args.placeholder_token, | 663 | placeholder_token=args.placeholder_token, |
| 693 | placeholder_token_ids=placeholder_token_ids, | 664 | placeholder_token_id=placeholder_token_id, |
| 694 | output_dir=basepath, | 665 | output_dir=basepath, |
| 695 | sample_image_size=args.sample_image_size, | 666 | sample_image_size=args.sample_image_size, |
| 696 | sample_batch_size=args.sample_batch_size, | 667 | sample_batch_size=args.sample_batch_size, |
| @@ -823,10 +794,8 @@ def main(): | |||
| 823 | else: | 794 | else: |
| 824 | grads = text_encoder.get_input_embeddings().weight.grad | 795 | grads = text_encoder.get_input_embeddings().weight.grad |
| 825 | # Get the index for tokens that we want to zero the grads for | 796 | # Get the index for tokens that we want to zero the grads for |
| 826 | grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0] | 797 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id |
| 827 | for i in range(1, len(placeholder_token_ids)): | 798 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) |
| 828 | grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]) | ||
| 829 | grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0) | ||
| 830 | 799 | ||
| 831 | optimizer.step() | 800 | optimizer.step() |
| 832 | if not accelerator.optimizer_step_was_skipped: | 801 | if not accelerator.optimizer_step_was_skipped: |
