| author | Volpeon <git@volpeon.ink> | 2023-04-09 11:29:31 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-09 11:29:31 +0200 |
| commit | ba9fd1a10746d85d2502c8a79ac49db63d346b04 (patch) | |
| tree | 568bf65a0a4dcea2c34de4006b5761d0d6564307 /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.gz textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.bz2 textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 66 |
1 file changed, 42 insertions(+), 24 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index daf8bc5..2d51800 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -458,6 +458,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
```
```diff
@@ -624,7 +630,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,8 +761,6 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir / "checkpoints"
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -777,7 +781,6 @@
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
         emb_decay=args.emb_decay,
```
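The two removals above take `checkpoint_output_dir` out of the one-time setup: the final hunk recomputes it per cycle and passes it at each `trainer(...)` call instead. A self-contained sketch of that `functools.partial` split, with hypothetical names:

```python
# Sketch (hypothetical names): partial pre-binds arguments that stay constant
# across training cycles; per-cycle values are supplied at each call instead.
from functools import partial
from pathlib import Path

def train(project: str, checkpoint_output_dir: Path, accelerator: str) -> None:
    print(f"[{accelerator}] {project} -> {checkpoint_output_dir}")

trainer = partial(train, accelerator="cuda")  # bound once for the whole run

for cycle in (1, 2):
    out_dir = Path("output") / f"token_{cycle}" / "checkpoints"
    trainer(project=f"token_{cycle}", checkpoint_output_dir=out_dir)  # varies per cycle
```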
```diff
@@ -793,11 +796,6 @@
     )
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
-        if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-        else:
-            sample_output_dir = output_dir / "samples"
-
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -809,7 +807,11 @@
 
         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
 
-        print(f"{i + 1}: {stats}")
+        print("")
+        print(f"============ TI batch {i + 1} ============")
+        print("")
+        print(stats)
+        print("")
 
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
@@ -868,20 +870,36 @@
             mid_point=args.lr_mid_point,
         )
 
-        trainer(
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=num_train_epochs,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=sample_output_dir,
-            sample_frequency=sample_frequency,
-            placeholder_tokens=placeholder_tokens,
-            placeholder_token_ids=placeholder_token_ids,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print(f"------------ TI cycle {training_iter} ------------")
+            print("")
+
+            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
+            sample_output_dir = output_dir / project / "samples"
+            checkpoint_output_dir = output_dir / project / "checkpoints"
+
+            trainer(
+                project=project,
+                train_dataloader=datamodule.train_dataloader,
+                val_dataloader=datamodule.val_dataloader,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                num_train_epochs=num_train_epochs,
+                # --
+                group_labels=["emb"],
+                checkpoint_output_dir=checkpoint_output_dir,
+                sample_output_dir=sample_output_dir,
+                sample_frequency=sample_frequency,
+                placeholder_tokens=placeholder_tokens,
+                placeholder_token_ids=placeholder_token_ids,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
```
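With the new loop, each affirmative answer at the prompt starts another numbered cycle whose samples and checkpoints land in their own subtree. Assuming a single placeholder token, written `<token>` here (the actual name comes from `--placeholder_tokens`), two cycles would presumably produce:

```
output_dir/
├── <token>_1/        # project name for cycle 1
│   ├── samples/
│   └── checkpoints/
└── <token>_2/        # project name for cycle 2
    ├── samples/
    └── checkpoints/
```

Note that `response.lower().strip() != "n"` treats any reply other than `n` as yes, so simply pressing Enter also starts another cycle.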
