diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 12 |
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/train_ti.py b/train_ti.py index c1c0eed..48858cc 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -17,7 +17,6 @@ import transformers | |||
17 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
18 | from data.csv import VlpnDataModule, keyword_filter | 18 | from data.csv import VlpnDataModule, keyword_filter |
19 | from training.functional import train, add_placeholder_tokens, get_models | 19 | from training.functional import train, add_placeholder_tokens, get_models |
20 | from training.lr import plot_metrics | ||
21 | from training.strategy.ti import textual_inversion_strategy | 20 | from training.strategy.ti import textual_inversion_strategy |
22 | from training.optimization import get_scheduler | 21 | from training.optimization import get_scheduler |
23 | from training.util import save_args | 22 | from training.util import save_args |
@@ -511,12 +510,12 @@ def parse_args(): | |||
511 | if isinstance(args.initializer_tokens, str): | 510 | if isinstance(args.initializer_tokens, str): |
512 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 511 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) |
513 | 512 | ||
514 | if len(args.initializer_tokens) == 0: | ||
515 | raise ValueError("You must specify --initializer_tokens") | ||
516 | |||
517 | if len(args.placeholder_tokens) == 0: | 513 | if len(args.placeholder_tokens) == 0: |
518 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 514 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
519 | 515 | ||
516 | if len(args.initializer_tokens) == 0: | ||
517 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
518 | |||
520 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 519 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
521 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 520 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
522 | 521 | ||
@@ -856,7 +855,7 @@ def main(): | |||
856 | mid_point=args.lr_mid_point, | 855 | mid_point=args.lr_mid_point, |
857 | ) | 856 | ) |
858 | 857 | ||
859 | metrics = trainer( | 858 | trainer( |
860 | project="textual_inversion", | 859 | project="textual_inversion", |
861 | train_dataloader=datamodule.train_dataloader, | 860 | train_dataloader=datamodule.train_dataloader, |
862 | val_dataloader=datamodule.val_dataloader, | 861 | val_dataloader=datamodule.val_dataloader, |
@@ -864,14 +863,13 @@ def main(): | |||
864 | lr_scheduler=lr_scheduler, | 863 | lr_scheduler=lr_scheduler, |
865 | num_train_epochs=num_train_epochs, | 864 | num_train_epochs=num_train_epochs, |
866 | # -- | 865 | # -- |
866 | group_labels=["emb"], | ||
867 | sample_output_dir=sample_output_dir, | 867 | sample_output_dir=sample_output_dir, |
868 | sample_frequency=sample_frequency, | 868 | sample_frequency=sample_frequency, |
869 | placeholder_tokens=placeholder_tokens, | 869 | placeholder_tokens=placeholder_tokens, |
870 | placeholder_token_ids=placeholder_token_ids, | 870 | placeholder_token_ids=placeholder_token_ids, |
871 | ) | 871 | ) |
872 | 872 | ||
873 | plot_metrics(metrics, metrics_output_file) | ||
874 | |||
875 | if not args.sequential: | 873 | if not args.sequential: |
876 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 874 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
877 | else: | 875 | else: |