Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  12
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index c1c0eed..48858cc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -511,12 +510,12 @@ def parse_args():
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -856,7 +855,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        metrics = trainer(
+        trainer(
             project="textual_inversion",
             train_dataloader=datamodule.train_dataloader,
             val_dataloader=datamodule.val_dataloader,
@@ -864,14 +863,13 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             # --
+            group_labels=["emb"],
             sample_output_dir=sample_output_dir,
             sample_frequency=sample_frequency,
             placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
         )
 
-        plot_metrics(metrics, metrics_output_file)
-
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
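Taken together with the import hunk, the main() changes retire the metrics/plotting path: the trainer's return value is no longer bound to metrics, the plot_metrics(metrics, metrics_output_file) call is gone, and the matching import from training.lr is removed. The one addition is group_labels=["emb"], presumably tagging the single trained parameter group (the token embeddings) for the training loop's bookkeeping; its consumer is not visible in this diff. In outline (a sketch of the call site, not the literal script):

# Before:
#     metrics = trainer(project="textual_inversion", ...)
#     plot_metrics(metrics, metrics_output_file)   # from training.lr
#
# After:
#     trainer(project="textual_inversion", group_labels=["emb"], ...)
#     # return value discarded; nothing left to plot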