 train_lora.py | 40 ++++++++++++++++++++++++++++++++--------
 1 file changed, 32 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index e4b5546..d8a4880 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -323,6 +323,17 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--learning_rate_emb",
+        type=float,
+        default=1e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--train_emb",
+        action="store_true",
+        help="Keep training text embeddings.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -731,11 +742,16 @@ def main():
             args.learning_rate_pti * args.pti_gradient_accumulation_steps *
             args.pti_batch_size * accelerator.num_processes
         )
+        args.learning_rate_emb = (
+            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
+            args.pti_batch_size * accelerator.num_processes
+        )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
         args.learning_rate_pti = 1e-6
+        args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
@@ -794,6 +810,9 @@ def main():
         args.lr_scheduler = "adafactor"
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
+        args.learning_rate_text = None
+        args.learning_rate_pti = None
+        args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -811,6 +830,8 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
+        args.learning_rate_pti = 1.0
+        args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -826,6 +847,8 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
+        args.learning_rate_pti = 1.0
+        args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -949,7 +972,8 @@ def main():
         sample_frequency=pti_sample_frequency,
     )
 
-    embeddings.persist()
+    if not args.train_emb:
+        embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -974,13 +998,13 @@ def main():
 
     params_to_optimize = []
     group_labels = []
-    # if len(args.placeholder_tokens) != 0:
-    #     params_to_optimize.append({
-    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-    #         "lr": args.learning_rate_text,
-    #         "weight_decay": 0,
-    #     })
-    #     group_labels.append("emb")
+    if len(args.placeholder_tokens) != 0 and args.train_emb:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            "lr": args.learning_rate_emb,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (
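
For orientation, the sketch below condenses what the patch does once applied: when --train_emb is set and placeholder tokens exist, the token-override embeddings keep training through the LoRA phase in their own optimizer group at --learning_rate_emb; otherwise embeddings.persist() freezes them after PTI as before. The helper name build_param_groups and the bare args namespace are illustrative assumptions, not code from train_lora.py.

# Sketch only; mirrors the patched parameter-group block in main(), with assumed names.
def build_param_groups(args, text_encoder):
    params_to_optimize = []
    group_labels = []
    if len(args.placeholder_tokens) != 0 and args.train_emb:
        # New embedding group: trained at --learning_rate_emb, without weight decay.
        params_to_optimize.append({
            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
            "lr": args.learning_rate_emb,
            "weight_decay": 0,
        })
        group_labels.append("emb")
    return params_to_optimize, group_labels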