diff options
author | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-08 07:58:14 +0200 |
commit | 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch) | |
tree | b1483a52fb853aecb7b73635cded3cce61edf125 /train_lora.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2 textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 75 |
1 file changed, 41 insertions, 34 deletions
diff --git a/train_lora.py b/train_lora.py index 9f17495..1626be6 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -19,7 +19,6 @@ import transformers | |||
19 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
20 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
21 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
22 | from training.lr import plot_metrics | ||
23 | from training.strategy.lora import lora_strategy | 22 | from training.strategy.lora import lora_strategy |
24 | from training.optimization import get_scheduler | 23 | from training.optimization import get_scheduler |
25 | from training.util import save_args | 24 | from training.util import save_args |
@@ -568,6 +567,9 @@ def parse_args(): | |||
568 | if len(args.placeholder_tokens) == 0: | 567 | if len(args.placeholder_tokens) == 0: |
569 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 568 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
570 | 569 | ||
570 | if len(args.initializer_tokens) == 0: | ||
571 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
572 | |||
571 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 573 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
572 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 574 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
573 | 575 | ||
@@ -918,7 +920,7 @@ def main(): | |||
918 | train_epochs=num_pti_epochs, | 920 | train_epochs=num_pti_epochs, |
919 | ) | 921 | ) |
920 | 922 | ||
921 | metrics = trainer( | 923 | trainer( |
922 | strategy=lora_strategy, | 924 | strategy=lora_strategy, |
923 | pti_mode=True, | 925 | pti_mode=True, |
924 | project="pti", | 926 | project="pti", |
@@ -929,12 +931,13 @@ def main(): | |||
929 | num_train_epochs=num_pti_epochs, | 931 | num_train_epochs=num_pti_epochs, |
930 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 932 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
931 | # -- | 933 | # -- |
934 | group_labels=["emb"], | ||
932 | sample_output_dir=pti_sample_output_dir, | 935 | sample_output_dir=pti_sample_output_dir, |
933 | checkpoint_output_dir=pti_checkpoint_output_dir, | 936 | checkpoint_output_dir=pti_checkpoint_output_dir, |
934 | sample_frequency=pti_sample_frequency, | 937 | sample_frequency=pti_sample_frequency, |
935 | ) | 938 | ) |
936 | 939 | ||
937 | plot_metrics(metrics, pti_output_dir / "lr.png") | 940 | # embeddings.persist() |
938 | 941 | ||
939 | # LORA | 942 | # LORA |
940 | # -------------------------------------------------------------------------------- | 943 | # -------------------------------------------------------------------------------- |
@@ -957,34 +960,39 @@ def main(): | |||
957 | ) * args.gradient_accumulation_steps | 960 | ) * args.gradient_accumulation_steps |
958 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 961 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) |
959 | 962 | ||
960 | lora_optimizer = create_optimizer( | 963 | params_to_optimize = [] |
961 | [ | 964 | group_labels = [] |
962 | { | 965 | if len(args.placeholder_tokens) != 0: |
963 | "params": ( | 966 | params_to_optimize.append({ |
964 | param | 967 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
965 | for param in unet.parameters() | 968 | "lr": args.learning_rate_text, |
966 | if param.requires_grad | 969 | "weight_decay": 0, |
967 | ), | 970 | }) |
968 | "lr": args.learning_rate_unet, | 971 | group_labels.append("emb") |
969 | }, | 972 | params_to_optimize += [ |
970 | { | 973 | { |
971 | "params": ( | 974 | "params": ( |
972 | param | 975 | param |
973 | for param in itertools.chain( | 976 | for param in unet.parameters() |
974 | text_encoder.text_model.encoder.parameters(), | 977 | if param.requires_grad |
975 | text_encoder.text_model.final_layer_norm.parameters(), | 978 | ), |
976 | ) | 979 | "lr": args.learning_rate_unet, |
977 | if param.requires_grad | 980 | }, |
978 | ), | 981 | { |
979 | "lr": args.learning_rate_text, | 982 | "params": ( |
980 | }, | 983 | param |
981 | { | 984 | for param in itertools.chain( |
982 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 985 | text_encoder.text_model.encoder.parameters(), |
983 | "lr": args.learning_rate_text, | 986 | text_encoder.text_model.final_layer_norm.parameters(), |
984 | "weight_decay": 0, | 987 | ) |
985 | }, | 988 | if param.requires_grad |
986 | ] | 989 | ), |
987 | ) | 990 | "lr": args.learning_rate_text, |
991 | }, | ||
992 | ] | ||
993 | group_labels += ["unet", "text"] | ||
994 | |||
995 | lora_optimizer = create_optimizer(params_to_optimize) | ||
988 | 996 | ||
989 | lora_lr_scheduler = create_lr_scheduler( | 997 | lora_lr_scheduler = create_lr_scheduler( |
990 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 998 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
@@ -993,7 +1001,7 @@ def main(): | |||
993 | train_epochs=num_train_epochs, | 1001 | train_epochs=num_train_epochs, |
994 | ) | 1002 | ) |
995 | 1003 | ||
996 | metrics = trainer( | 1004 | trainer( |
997 | strategy=lora_strategy, | 1005 | strategy=lora_strategy, |
998 | project="lora", | 1006 | project="lora", |
999 | train_dataloader=lora_datamodule.train_dataloader, | 1007 | train_dataloader=lora_datamodule.train_dataloader, |
@@ -1003,13 +1011,12 @@ def main(): | |||
1003 | num_train_epochs=num_train_epochs, | 1011 | num_train_epochs=num_train_epochs, |
1004 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1012 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
1005 | # -- | 1013 | # -- |
1014 | group_labels=group_labels, | ||
1006 | sample_output_dir=lora_sample_output_dir, | 1015 | sample_output_dir=lora_sample_output_dir, |
1007 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1016 | checkpoint_output_dir=lora_checkpoint_output_dir, |
1008 | sample_frequency=lora_sample_frequency, | 1017 | sample_frequency=lora_sample_frequency, |
1009 | ) | 1018 | ) |
1010 | 1019 | ||
1011 | plot_metrics(metrics, lora_output_dir / "lr.png") | ||
1012 | |||
1013 | 1020 | ||
1014 | if __name__ == "__main__": | 1021 | if __name__ == "__main__": |
1015 | main() | 1022 | main() |