| author | Volpeon <git@volpeon.ink> | 2023-01-16 17:09:01 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-16 17:09:01 +0100 |
| commit | 36440e48ce279872d6e736bcb1bf57d13da73a11 | |
| tree | 8ba9593d8a887517c70b01932c137c9c3f759e8f /train_dreambooth.py | |
| parent | More training adjustments | |
Moved multi-TI code from Dreambooth to TI script
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 135 |
1 file changed, 2 insertions, 133 deletions
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -200,23 +200,6 @@ def parse_args():
         default=100
     )
     parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
@@ -245,12 +228,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -482,12 +459,6 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -521,8 +492,6 @@ def main():
 
     set_seed(args.seed)
 
-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
```
