Diffstat (limited to 'train_lora.py')
 -rw-r--r--  train_lora.py  75
 1 file changed, 41 insertions, 34 deletions
diff --git a/train_lora.py b/train_lora.py
index 9f17495..1626be6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -19,7 +19,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -568,6 +567,9 @@ def parse_args():
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

@@ -918,7 +920,7 @@ def main():
         train_epochs=num_pti_epochs,
     )

-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         pti_mode=True,
         project="pti",
@@ -929,12 +931,13 @@ def main():
         num_train_epochs=num_pti_epochs,
         gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
+        group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
     )

-    plot_metrics(metrics, pti_output_dir / "lr.png")
+    # embeddings.persist()

     # LORA
     # --------------------------------------------------------------------------------
@@ -957,34 +960,39 @@ def main():
     ) * args.gradient_accumulation_steps
     lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))

-    lora_optimizer = create_optimizer(
-        [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_unet,
-            },
-            {
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_text,
-            },
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_text,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    params_to_optimize = []
+    group_labels = []
+    if len(args.placeholder_tokens) != 0:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            "lr": args.learning_rate_text,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
+    params_to_optimize += [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_text,
+        },
+    ]
+    group_labels += ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)

     lora_lr_scheduler = create_lr_scheduler(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -993,7 +1001,7 @@ def main():
         train_epochs=num_train_epochs,
     )

-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         project="lora",
         train_dataloader=lora_datamodule.train_dataloader,
@@ -1003,13 +1011,12 @@ def main():
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
+        group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
     )

-    plot_metrics(metrics, lora_output_dir / "lr.png")
-

 if __name__ == "__main__":
     main()
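
Note (not part of the commit): the sketch below is a minimal, self-contained illustration of the parameter-group pattern the diff introduces, where the token-override embedding parameters form a separate "emb" group only while placeholder tokens are trained, and a parallel group_labels list tracks the groups. TinyTextEncoder and build_param_groups are hypothetical stand-ins, not code from this repository.

# Hypothetical sketch of the conditional param-group / group-labels pattern.
import itertools

import torch
from torch import nn


class TinyTextEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.token_override_embedding = nn.Embedding(4, 8)  # stand-in for the override embedding
        self.encoder = nn.Linear(8, 8)
        self.final_layer_norm = nn.LayerNorm(8)


def build_param_groups(text_encoder, unet, placeholder_tokens, lr_text, lr_unet):
    params_to_optimize = []
    group_labels = []

    # Embedding group exists only while placeholder tokens are being trained.
    if len(placeholder_tokens) != 0:
        params_to_optimize.append({
            "params": text_encoder.token_override_embedding.parameters(),
            "lr": lr_text,
            "weight_decay": 0,
        })
        group_labels.append("emb")

    params_to_optimize += [
        {
            "params": (p for p in unet.parameters() if p.requires_grad),
            "lr": lr_unet,
        },
        {
            "params": (
                p
                for p in itertools.chain(
                    text_encoder.encoder.parameters(),
                    text_encoder.final_layer_norm.parameters(),
                )
                if p.requires_grad
            ),
            "lr": lr_text,
        },
    ]
    group_labels += ["unet", "text"]
    return params_to_optimize, group_labels


if __name__ == "__main__":
    groups, labels = build_param_groups(
        TinyTextEncoder(), nn.Linear(8, 8), ["<*0>"], lr_text=1e-5, lr_unet=1e-4
    )
    optimizer = torch.optim.AdamW(groups)  # any optimizer factory that accepts param groups works
    print(labels, [group["lr"] for group in optimizer.param_groups])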