-rw-r--r--  train_lora.py  40
1 file changed, 32 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index e4b5546..d8a4880 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -323,6 +323,17 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--learning_rate_emb",
+        type=float,
+        default=1e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--train_emb",
+        action="store_true",
+        help="Keep training text embeddings.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -731,11 +742,16 @@ def main():
             args.learning_rate_pti * args.pti_gradient_accumulation_steps *
             args.pti_batch_size * accelerator.num_processes
         )
+        args.learning_rate_emb = (
+            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
+            args.pti_batch_size * accelerator.num_processes
+        )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
         args.learning_rate_pti = 1e-6
+        args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
@@ -794,6 +810,9 @@ def main():
         args.lr_scheduler = "adafactor"
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
+        args.learning_rate_text = None
+        args.learning_rate_pti = None
+        args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -811,6 +830,8 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
+        args.learning_rate_pti = 1.0
+        args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -826,6 +847,8 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
+        args.learning_rate_pti = 1.0
+        args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -949,7 +972,8 @@ def main():
         sample_frequency=pti_sample_frequency,
     )
 
-    embeddings.persist()
+    if not args.train_emb:
+        embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -974,13 +998,13 @@ def main():
 
     params_to_optimize = []
     group_labels = []
-    # if len(args.placeholder_tokens) != 0:
-    #     params_to_optimize.append({
-    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-    #         "lr": args.learning_rate_text,
-    #         "weight_decay": 0,
-    #     })
-    #     group_labels.append("emb")
+    if len(args.placeholder_tokens) != 0 and args.train_emb:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            "lr": args.learning_rate_emb,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (