diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 12 |
1 file changed, 5 insertions, 7 deletions
diff --git a/train_ti.py b/train_ti.py index c1c0eed..48858cc 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -17,7 +17,6 @@ import transformers | |||
| 17 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
| 18 | from data.csv import VlpnDataModule, keyword_filter | 18 | from data.csv import VlpnDataModule, keyword_filter |
| 19 | from training.functional import train, add_placeholder_tokens, get_models | 19 | from training.functional import train, add_placeholder_tokens, get_models |
| 20 | from training.lr import plot_metrics | ||
| 21 | from training.strategy.ti import textual_inversion_strategy | 20 | from training.strategy.ti import textual_inversion_strategy |
| 22 | from training.optimization import get_scheduler | 21 | from training.optimization import get_scheduler |
| 23 | from training.util import save_args | 22 | from training.util import save_args |
| @@ -511,12 +510,12 @@ def parse_args(): | |||
| 511 | if isinstance(args.initializer_tokens, str): | 510 | if isinstance(args.initializer_tokens, str): |
| 512 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 511 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) |
| 513 | 512 | ||
| 514 | if len(args.initializer_tokens) == 0: | ||
| 515 | raise ValueError("You must specify --initializer_tokens") | ||
| 516 | |||
| 517 | if len(args.placeholder_tokens) == 0: | 513 | if len(args.placeholder_tokens) == 0: |
| 518 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 514 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
| 519 | 515 | ||
| 516 | if len(args.initializer_tokens) == 0: | ||
| 517 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
| 518 | |||
| 520 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 519 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 521 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 520 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
| 522 | 521 | ||
| @@ -856,7 +855,7 @@ def main(): | |||
| 856 | mid_point=args.lr_mid_point, | 855 | mid_point=args.lr_mid_point, |
| 857 | ) | 856 | ) |
| 858 | 857 | ||
| 859 | metrics = trainer( | 858 | trainer( |
| 860 | project="textual_inversion", | 859 | project="textual_inversion", |
| 861 | train_dataloader=datamodule.train_dataloader, | 860 | train_dataloader=datamodule.train_dataloader, |
| 862 | val_dataloader=datamodule.val_dataloader, | 861 | val_dataloader=datamodule.val_dataloader, |
| @@ -864,14 +863,13 @@ def main(): | |||
| 864 | lr_scheduler=lr_scheduler, | 863 | lr_scheduler=lr_scheduler, |
| 865 | num_train_epochs=num_train_epochs, | 864 | num_train_epochs=num_train_epochs, |
| 866 | # -- | 865 | # -- |
| 866 | group_labels=["emb"], | ||
| 867 | sample_output_dir=sample_output_dir, | 867 | sample_output_dir=sample_output_dir, |
| 868 | sample_frequency=sample_frequency, | 868 | sample_frequency=sample_frequency, |
| 869 | placeholder_tokens=placeholder_tokens, | 869 | placeholder_tokens=placeholder_tokens, |
| 870 | placeholder_token_ids=placeholder_token_ids, | 870 | placeholder_token_ids=placeholder_token_ids, |
| 871 | ) | 871 | ) |
| 872 | 872 | ||
| 873 | plot_metrics(metrics, metrics_output_file) | ||
| 874 | |||
| 875 | if not args.sequential: | 873 | if not args.sequential: |
| 876 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 874 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
| 877 | else: | 875 | else: |
