| author    | Volpeon <git@volpeon.ink> | 2023-04-10 10:34:12 +0200 |
|-----------|---------------------------|---------------------------|
| committer | Volpeon <git@volpeon.ink> | 2023-04-10 10:34:12 +0200 |
| commit    | eb6a92abda5893c975437026cdaf0ce0bfefe2a4 (patch) | |
| tree      | a1525010b48362986e0cc2b7c3f7505a35dea71a /train_ti.py | |
| parent    | Update (diff) | |
| download  | textual-inversion-diff-eb6a92abda5893c975437026cdaf0ce0bfefe2a4.tar.gz, textual-inversion-diff-eb6a92abda5893c975437026cdaf0ce0bfefe2a4.tar.bz2, textual-inversion-diff-eb6a92abda5893c975437026cdaf0ce0bfefe2a4.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 68 |
|------------|-------------|----|

1 file changed, 41 insertions(+), 27 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index ebac302..eb08bda 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -152,6 +152,11 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--train_dir_embeddings",
+        action="store_true",
+        help="Train embeddings loaded from embeddings directory.",
+    )
+    parser.add_argument(
         "--collection",
         type=str,
         nargs='*',
```
```diff
@@ -404,6 +409,12 @@ def parse_args():
         help="If checkpoints are saved on maximum accuracy",
     )
     parser.add_argument(
+        "--sample_num",
+        type=int,
+        default=None,
+        help="How often to save a checkpoint and sample image (in number of samples)",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
```
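For reference, a minimal self-contained sketch of how the two new flags behave when parsed in isolation (the explicit argv list below is illustrative only, not a recommended invocation):

```python
import argparse

# Subset of train_ti.py's parser: just the two options added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_dir_embeddings",
    action="store_true",
    help="Train embeddings loaded from embeddings directory.",
)
parser.add_argument(
    "--sample_num",
    type=int,
    default=None,
    help="How often to save a checkpoint and sample image (in number of samples)",
)

# Both flags are optional: when omitted they default to False / None.
args = parser.parse_args(["--train_dir_embeddings", "--sample_num", "10"])
print(args.train_dir_embeddings, args.sample_num)  # True 10
```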
```diff
@@ -669,9 +680,14 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+        if args.train_dir_embeddings:
+            args.placeholder_tokens = added_tokens
+            print("Training embeddings from embeddings dir")
+        else:
+            embeddings.persist()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
```
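This hunk is the behavioral core of the change: embeddings loaded from --embeddings_dir were previously always persisted (frozen into the base embedding table); with --train_dir_embeddings they become the placeholder tokens to be trained instead. A simplified sketch of that branch, using a hypothetical stand-in for the repo's embeddings object:

```python
# Hypothetical stand-ins for illustration; the real code operates on the
# text encoder's trainable token-override embedding.
def choose_placeholders(args, added_tokens, embeddings):
    if args.train_dir_embeddings:
        # Loaded vectors stay trainable and are themselves the placeholders.
        args.placeholder_tokens = added_tokens
        print("Training embeddings from embeddings dir")
    else:
        # Loaded vectors are frozen; training targets args.placeholder_tokens.
        embeddings.persist()
    return args.placeholder_tokens
```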
```diff
@@ -852,28 +868,8 @@ def main():
             args.num_train_steps / len(datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
-
-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e3,
-        train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
-    )
+    if args.sample_num is not None:
+        sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
     training_iter = 0
 
```
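Worked example of the new override (epoch count assumed for illustration): --sample_num fixes roughly how many sample/checkpoint points the whole run produces, and the per-epoch frequency is derived from it:

```python
import math

num_train_epochs = 30  # assumed for illustration
sample_num = 10        # as if passed via --sample_num 10

# Sample/checkpoint roughly sample_num times over the run,
# i.e. once every ceil(30 / 10) = 3 epochs.
sample_frequency = math.ceil(num_train_epochs / sample_num)
print(sample_frequency)  # 3
```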
```diff
@@ -888,6 +884,28 @@ def main():
         print(f"------------ TI cycle {training_iter} ------------")
         print("")
 
+        optimizer = create_optimizer(
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            lr=args.learning_rate,
+        )
+
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            end_lr=1e3,
+            train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+            mid_point=args.lr_mid_point,
+        )
+
         project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
         sample_output_dir = output_dir / project / "samples"
         checkpoint_output_dir = output_dir / project / "checkpoints"
```
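Combined with the deletion in the earlier hunk, optimizer and LR-scheduler construction now happens once per TI cycle instead of once per run, so optimizer state and warmup/annealing restart fresh each cycle. A generic PyTorch sketch of that effect (plain torch.optim calls, not the repo's create_optimizer/get_scheduler helpers):

```python
import torch

param = torch.nn.Parameter(torch.zeros(4))

for cycle in range(3):
    # Rebuilt every cycle: optimizer state (e.g. Adam moments) and the
    # LR schedule both start over instead of continuing across cycles.
    optimizer = torch.optim.AdamW([param], lr=1e-2)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-2, total_steps=10
    )
    for _ in range(10):
        optimizer.zero_grad()
        loss = (param ** 2).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f"cycle {cycle}: final lr = {scheduler.get_last_lr()[0]:.2e}")
```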
```diff
@@ -908,10 +926,6 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )
 
-        response = input("Run another cycle? [y/n] ")
-        continue_training = response.lower().strip() != "n"
-        training_iter += 1
-
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
```
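The lines removed here were the interactive gate between cycles, so dropping them lets multi-cycle runs proceed unattended. For reference, a runnable sketch of the old control flow:

```python
# Old behavior (sketch): each cycle ended by asking whether to continue.
training_iter = 0
continue_training = True
while continue_training:
    # ... one TI training cycle would run here ...
    response = input("Run another cycle? [y/n] ")
    continue_training = response.lower().strip() != "n"
    training_iter += 1
```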