diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-19 17:45:56 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-19 17:45:56 +0200 |
| commit | 5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa (patch) | |
| tree | fdf3c57c2f238e5381ad5e961e704fedd3e6816a /dreambooth.py | |
| parent | Updated Dreambooth training (diff) | |
| download | textual-inversion-diff-5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa.tar.gz textual-inversion-diff-5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa.tar.bz2 textual-inversion-diff-5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa.zip | |
Dreambooth: Added option to insert a new input token; removed Dreambooth Plus
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 48 |
1 files changed, 42 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py index d1cf535..da8399f 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -69,6 +69,18 @@ def parse_args(): | |||
| 69 | help="A token to use as a placeholder for the concept.", | 69 | help="A token to use as a placeholder for the concept.", |
| 70 | ) | 70 | ) |
| 71 | parser.add_argument( | 71 | parser.add_argument( |
| 72 | "--placeholder_token", | ||
| 73 | type=str, | ||
| 74 | default="<*>", | ||
| 75 | help="A token to use as a placeholder for the concept.", | ||
| 76 | ) | ||
| 77 | parser.add_argument( | ||
| 78 | "--initializer_token", | ||
| 79 | type=str, | ||
| 80 | default=None, | ||
| 81 | help="A token to use as initializer word." | ||
| 82 | ) | ||
| 83 | parser.add_argument( | ||
| 72 | "--num_class_images", | 84 | "--num_class_images", |
| 73 | type=int, | 85 | type=int, |
| 74 | default=400, | 86 | default=400, |
| @@ -114,7 +126,7 @@ def parse_args(): | |||
| 114 | parser.add_argument( | 126 | parser.add_argument( |
| 115 | "--max_train_steps", | 127 | "--max_train_steps", |
| 116 | type=int, | 128 | type=int, |
| 117 | default=3600, | 129 | default=6000, |
| 118 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 130 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 119 | ) | 131 | ) |
| 120 | parser.add_argument( | 132 | parser.add_argument( |
| @@ -131,13 +143,13 @@ def parse_args(): | |||
| 131 | parser.add_argument( | 143 | parser.add_argument( |
| 132 | "--learning_rate_unet", | 144 | "--learning_rate_unet", |
| 133 | type=float, | 145 | type=float, |
| 134 | default=3e-6, | 146 | default=2e-6, |
| 135 | help="Initial learning rate (after the potential warmup period) to use.", | 147 | help="Initial learning rate (after the potential warmup period) to use.", |
| 136 | ) | 148 | ) |
| 137 | parser.add_argument( | 149 | parser.add_argument( |
| 138 | "--learning_rate_text", | 150 | "--learning_rate_text", |
| 139 | type=float, | 151 | type=float, |
| 140 | default=3e-6, | 152 | default=2e-6, |
| 141 | help="Initial learning rate (after the potential warmup period) to use.", | 153 | help="Initial learning rate (after the potential warmup period) to use.", |
| 142 | ) | 154 | ) |
| 143 | parser.add_argument( | 155 | parser.add_argument( |
| @@ -476,8 +488,13 @@ class Checkpointer: | |||
| 476 | def main(): | 488 | def main(): |
| 477 | args = parse_args() | 489 | args = parse_args() |
| 478 | 490 | ||
| 491 | instance_identifier = args.instance_identifier | ||
| 492 | |||
| 493 | if args.placeholder_token is not None: | ||
| 494 | instance_identifier = instance_identifier.format(args.placeholder_token) | ||
| 495 | |||
| 479 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 496 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 480 | basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now) | 497 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
| 481 | basepath.mkdir(parents=True, exist_ok=True) | 498 | basepath.mkdir(parents=True, exist_ok=True) |
| 482 | 499 | ||
| 483 | accelerator = Accelerator( | 500 | accelerator = Accelerator( |
| @@ -514,6 +531,25 @@ def main(): | |||
| 514 | device=accelerator.device | 531 | device=accelerator.device |
| 515 | ) if args.use_ema else None | 532 | ) if args.use_ema else None |
| 516 | 533 | ||
| 534 | if args.initializer_token is not None: | ||
| 535 | # Convert the initializer_token, placeholder_token to ids | ||
| 536 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
| 537 | print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.") | ||
| 538 | initializer_token_ids = torch.tensor(initializer_token_ids[:1]) | ||
| 539 | |||
| 540 | # Add the placeholder token in tokenizer | ||
| 541 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | ||
| 542 | if num_added_tokens == 0: | ||
| 543 | print(f"Re-using existing token {args.placeholder_token}.") | ||
| 544 | else: | ||
| 545 | print(f"Training new token {args.placeholder_token}.") | ||
| 546 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | ||
| 547 | |||
| 548 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
| 549 | token_embeds = text_encoder.get_input_embeddings() | ||
| 550 | initializer_token_embeddings = token_embeds(initializer_token_ids) | ||
| 551 | token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings | ||
| 552 | |||
| 517 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 553 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 518 | 554 | ||
| 519 | if args.gradient_checkpointing: | 555 | if args.gradient_checkpointing: |
| @@ -605,7 +641,7 @@ def main(): | |||
| 605 | data_file=args.train_data_file, | 641 | data_file=args.train_data_file, |
| 606 | batch_size=args.train_batch_size, | 642 | batch_size=args.train_batch_size, |
| 607 | prompt_processor=prompt_processor, | 643 | prompt_processor=prompt_processor, |
| 608 | instance_identifier=args.instance_identifier, | 644 | instance_identifier=instance_identifier, |
| 609 | class_identifier=args.class_identifier, | 645 | class_identifier=args.class_identifier, |
| 610 | class_subdir="cls", | 646 | class_subdir="cls", |
| 611 | num_class_images=args.num_class_images, | 647 | num_class_images=args.num_class_images, |
| @@ -735,7 +771,7 @@ def main(): | |||
| 735 | tokenizer=tokenizer, | 771 | tokenizer=tokenizer, |
| 736 | text_encoder=text_encoder, | 772 | text_encoder=text_encoder, |
| 737 | output_dir=basepath, | 773 | output_dir=basepath, |
| 738 | instance_identifier=args.instance_identifier, | 774 | instance_identifier=instance_identifier, |
| 739 | sample_image_size=args.sample_image_size, | 775 | sample_image_size=args.sample_image_size, |
| 740 | sample_batch_size=args.sample_batch_size, | 776 | sample_batch_size=args.sample_batch_size, |
| 741 | sample_batches=args.sample_batches, | 777 | sample_batches=args.sample_batches, |
