diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 68 |
1 files changed, 42 insertions, 26 deletions
diff --git a/dreambooth.py b/dreambooth.py index 5c26f12..2c24908 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -71,13 +71,13 @@ def parse_args(): | |||
| 71 | parser.add_argument( | 71 | parser.add_argument( |
| 72 | "--placeholder_token", | 72 | "--placeholder_token", |
| 73 | type=str, | 73 | type=str, |
| 74 | default="<*>", | 74 | nargs='*', default=[], |
| 75 | help="A token to use as a placeholder for the concept.", | 75 | help="A token to use as a placeholder for the concept.", |
| 76 | ) | 76 | ) |
| 77 | parser.add_argument( | 77 | parser.add_argument( |
| 78 | "--initializer_token", | 78 | "--initializer_token", |
| 79 | type=str, | 79 | type=str, |
| 80 | default=None, | 80 | nargs='*', default=[], |
| 81 | help="A token to use as initializer word." | 81 | help="A token to use as initializer word." |
| 82 | ) | 82 | ) |
| 83 | parser.add_argument( | 83 | parser.add_argument( |
| @@ -316,6 +316,18 @@ def parse_args(): | |||
| 316 | if args.instance_identifier is None: | 316 | if args.instance_identifier is None: |
| 317 | raise ValueError("You must specify --instance_identifier") | 317 | raise ValueError("You must specify --instance_identifier") |
| 318 | 318 | ||
| 319 | if isinstance(args.initializer_token, str): | ||
| 320 | args.initializer_token = [args.initializer_token] | ||
| 321 | |||
| 322 | if isinstance(args.placeholder_token, str): | ||
| 323 | args.placeholder_token = [args.placeholder_token] | ||
| 324 | |||
| 325 | if len(args.placeholder_token) == 0: | ||
| 326 | args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))] | ||
| 327 | |||
| 328 | if len(args.placeholder_token) != len(args.initializer_token): | ||
| 329 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | ||
| 330 | |||
| 319 | if args.output_dir is None: | 331 | if args.output_dir is None: |
| 320 | raise ValueError("You must specify --output_dir") | 332 | raise ValueError("You must specify --output_dir") |
| 321 | 333 | ||
| @@ -379,9 +391,6 @@ class Checkpointer: | |||
| 379 | 391 | ||
| 380 | @torch.no_grad() | 392 | @torch.no_grad() |
| 381 | def save_embedding(self, step, postfix): | 393 | def save_embedding(self, step, postfix): |
| 382 | if self.placeholder_token_id is None: | ||
| 383 | return | ||
| 384 | |||
| 385 | print("Saving checkpoint for step %d..." % step) | 394 | print("Saving checkpoint for step %d..." % step) |
| 386 | 395 | ||
| 387 | checkpoints_path = self.output_dir.joinpath("checkpoints") | 396 | checkpoints_path = self.output_dir.joinpath("checkpoints") |
| @@ -389,12 +398,13 @@ class Checkpointer: | |||
| 389 | 398 | ||
| 390 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 399 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) |
| 391 | 400 | ||
| 392 | # Save a checkpoint | 401 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
| 393 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | 402 | # Save a checkpoint |
| 394 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | 403 | learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id] |
| 404 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | ||
| 395 | 405 | ||
| 396 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | 406 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
| 397 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | 407 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) |
| 398 | 408 | ||
| 399 | del unwrapped | 409 | del unwrapped |
| 400 | del learned_embeds | 410 | del learned_embeds |
| @@ -467,7 +477,7 @@ class Checkpointer: | |||
| 467 | for i in range(self.sample_batches): | 477 | for i in range(self.sample_batches): |
| 468 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 478 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
| 469 | prompt = [ | 479 | prompt = [ |
| 470 | prompt.format(self.instance_identifier) | 480 | prompt.format(identifier=self.instance_identifier) |
| 471 | for batch in batches | 481 | for batch in batches |
| 472 | for prompt in batch["prompts"] | 482 | for prompt in batch["prompts"] |
| 473 | ][:self.sample_batch_size] | 483 | ][:self.sample_batch_size] |
| @@ -516,8 +526,8 @@ def main(): | |||
| 516 | 526 | ||
| 517 | instance_identifier = args.instance_identifier | 527 | instance_identifier = args.instance_identifier |
| 518 | 528 | ||
| 519 | if args.placeholder_token is not None: | 529 | if len(args.placeholder_token) != 0: |
| 520 | instance_identifier = instance_identifier.format(args.placeholder_token) | 530 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) |
| 521 | 531 | ||
| 522 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 532 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 523 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 533 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
| @@ -565,18 +575,16 @@ def main(): | |||
| 565 | # Freeze text_encoder and vae | 575 | # Freeze text_encoder and vae |
| 566 | freeze_params(vae.parameters()) | 576 | freeze_params(vae.parameters()) |
| 567 | 577 | ||
| 568 | if args.initializer_token is not None: | 578 | if len(args.initializer_token) != 0: |
| 569 | # Convert the initializer_token, placeholder_token to ids | 579 | # Convert the initializer_token, placeholder_token to ids |
| 570 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | 580 | initializer_token_ids = torch.stack([ |
| 571 | print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.") | 581 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
| 572 | initializer_token_ids = torch.tensor(initializer_token_ids[:1]) | 582 | for token in args.initializer_token |
| 583 | ]) | ||
| 573 | 584 | ||
| 574 | # Add the placeholder token in tokenizer | ||
| 575 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | 585 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) |
| 576 | if num_added_tokens == 0: | 586 | print(f"Added {num_added_tokens} new tokens.") |
| 577 | print(f"Re-using existing token {args.placeholder_token}.") | 587 | |
| 578 | else: | ||
| 579 | print(f"Training new token {args.placeholder_token}.") | ||
| 580 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 588 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
| 581 | 589 | ||
| 582 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 590 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
| @@ -586,7 +594,9 @@ def main(): | |||
| 586 | token_embeds = text_encoder.get_input_embeddings().weight.data | 594 | token_embeds = text_encoder.get_input_embeddings().weight.data |
| 587 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 595 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) |
| 588 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 596 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
| 589 | token_embeds[placeholder_token_id] = initializer_token_embeddings | 597 | |
| 598 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | ||
| 599 | token_embeds[token_id] = embeddings | ||
| 590 | 600 | ||
| 591 | freeze_params(itertools.chain( | 601 | freeze_params(itertools.chain( |
| 592 | text_encoder.text_model.encoder.parameters(), | 602 | text_encoder.text_model.encoder.parameters(), |
| @@ -594,7 +604,7 @@ def main(): | |||
| 594 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 604 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
| 595 | )) | 605 | )) |
| 596 | else: | 606 | else: |
| 597 | placeholder_token_id = None | 607 | placeholder_token_id = [] |
| 598 | 608 | ||
| 599 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 609 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 600 | 610 | ||
| @@ -721,7 +731,7 @@ def main(): | |||
| 721 | with torch.inference_mode(): | 731 | with torch.inference_mode(): |
| 722 | for batch in batched_data: | 732 | for batch in batched_data: |
| 723 | image_name = [item.class_image_path for item in batch] | 733 | image_name = [item.class_image_path for item in batch] |
| 724 | prompt = [item.prompt.format(args.class_identifier) for item in batch] | 734 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] |
| 725 | nprompt = [item.nprompt for item in batch] | 735 | nprompt = [item.nprompt for item in batch] |
| 726 | 736 | ||
| 727 | images = pipeline( | 737 | images = pipeline( |
| @@ -787,7 +797,10 @@ def main(): | |||
| 787 | # We need to initialize the trackers we use, and also store our configuration. | 797 | # We need to initialize the trackers we use, and also store our configuration. |
| 788 | # The trackers initializes automatically on the main process. | 798 | # The trackers initializes automatically on the main process. |
| 789 | if accelerator.is_main_process: | 799 | if accelerator.is_main_process: |
| 790 | accelerator.init_trackers("dreambooth", config=vars(args)) | 800 | config = vars(args).copy() |
| 801 | config["initializer_token"] = " ".join(config["initializer_token"]) | ||
| 802 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | ||
| 803 | accelerator.init_trackers("dreambooth", config=config) | ||
| 791 | 804 | ||
| 792 | # Train! | 805 | # Train! |
| 793 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | 806 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
| @@ -932,6 +945,9 @@ def main(): | |||
| 932 | global_step += 1 | 945 | global_step += 1 |
| 933 | 946 | ||
| 934 | if global_step % args.sample_frequency == 0: | 947 | if global_step % args.sample_frequency == 0: |
| 948 | local_progress_bar.clear() | ||
| 949 | global_progress_bar.clear() | ||
| 950 | |||
| 935 | checkpointer.save_embedding(global_step, "training") | 951 | checkpointer.save_embedding(global_step, "training") |
| 936 | sample_checkpoint = True | 952 | sample_checkpoint = True |
| 937 | 953 | ||
