Diffstat (limited to 'train_lora.py')
-rw-r--r--   train_lora.py   75
1 file changed, 41 insertions, 34 deletions
diff --git a/train_lora.py b/train_lora.py
index 9f17495..1626be6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -19,7 +19,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -568,6 +567,9 @@ def parse_args():
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -918,7 +920,7 @@ def main():
         train_epochs=num_pti_epochs,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         pti_mode=True,
         project="pti",
@@ -929,12 +931,13 @@ def main():
         num_train_epochs=num_pti_epochs,
         gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
+        group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
     )
 
-    plot_metrics(metrics, pti_output_dir / "lr.png")
+    # embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -957,34 +960,39 @@ def main():
     ) * args.gradient_accumulation_steps
     lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
 
-    lora_optimizer = create_optimizer(
-        [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_unet,
-            },
-            {
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_text,
-            },
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_text,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    params_to_optimize = []
+    group_labels = []
+    if len(args.placeholder_tokens) != 0:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            "lr": args.learning_rate_text,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
+    params_to_optimize += [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_text,
+        },
+    ]
+    group_labels += ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)
 
     lora_lr_scheduler = create_lr_scheduler(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -993,7 +1001,7 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         project="lora",
         train_dataloader=lora_datamodule.train_dataloader,
@@ -1003,13 +1011,12 @@ def main():
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
+        group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
     )
 
-    plot_metrics(metrics, lora_output_dir / "lr.png")
-
 
 if __name__ == "__main__":
     main()
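
The reworked block above builds the optimizer parameter groups and a parallel group_labels list in lockstep, so the embedding group only exists (and is only labelled) when placeholder tokens are actually being trained. Below is a minimal, self-contained sketch of that pattern, assuming plain torch.optim.AdamW in place of the repo's create_optimizer() and toy nn modules in place of the real UNet and text encoder; every name in the sketch is illustrative, not taken from train_lora.py.

# Sketch only: torch.optim.AdamW stands in for create_optimizer(), and the toy
# modules stand in for the token-override embedding, the UNet and the text encoder.
import torch
import torch.nn as nn

toy_embedding = nn.Embedding(4, 8)     # plays the role of token_override_embedding.params
toy_unet = nn.Linear(8, 8)             # plays the role of the LoRA-wrapped UNet
toy_text_encoder = nn.Linear(8, 8)     # plays the role of the text encoder

learning_rate_text = 1e-5
learning_rate_unet = 1e-4
placeholder_tokens = ["<*0>"]

params_to_optimize = []
group_labels = []

# The embedding group is only appended when placeholder tokens are trained,
# mirroring the conditional introduced in the diff.
if len(placeholder_tokens) != 0:
    params_to_optimize.append({
        "params": toy_embedding.parameters(),
        "lr": learning_rate_text,
        "weight_decay": 0,
    })
    group_labels.append("emb")

params_to_optimize += [
    {"params": toy_unet.parameters(), "lr": learning_rate_unet},
    {"params": toy_text_encoder.parameters(), "lr": learning_rate_text},
]
group_labels += ["unet", "text"]

optimizer = torch.optim.AdamW(params_to_optimize)

# group_labels[i] names optimizer.param_groups[i], e.g. for per-group LR reporting.
for label, group in zip(group_labels, optimizer.param_groups):
    print(label, group["lr"])

Because the two lists are built in the same order, group_labels[i] always names optimizer.param_groups[i], which is presumably how the trainer uses the group_labels argument it now receives (for example, to report a learning rate per group).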
