 train_lora.py | 40 ++++++++++++++++++++++++++++----------
 1 file changed, 32 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index e4b5546..d8a4880 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -323,6 +323,17 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--learning_rate_emb",
+        type=float,
+        default=1e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--train_emb",
+        action="store_true",
+        help="Keep training text embeddings.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -731,11 +742,16 @@ def main():
             args.learning_rate_pti * args.pti_gradient_accumulation_steps *
             args.pti_batch_size * accelerator.num_processes
         )
+        args.learning_rate_emb = (
+            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
+            args.pti_batch_size * accelerator.num_processes
+        )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
         args.learning_rate_pti = 1e-6
+        args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
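As with the other rates, scaling multiplies the embedding learning rate by the gradient-accumulation steps, the per-device batch size, and the number of processes. A hedged sketch of that rule (the helper name and example values below are illustrative, not from the script):

# Illustrative helper mirroring the scaling applied to args.learning_rate_emb.
def scaled_lr(base_lr, grad_accum_steps, batch_size, num_processes):
    return base_lr * grad_accum_steps * batch_size * num_processes

# e.g. a base rate of 1e-5 with 4 accumulation steps, batch size 2, and 2 GPUs
print(scaled_lr(1e-5, 4, 2, 2))  # 0.00016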
@@ -794,6 +810,9 @@ def main():
         args.lr_scheduler = "adafactor"
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
+        args.learning_rate_text = None
+        args.learning_rate_pti = None
+        args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -811,6 +830,8 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
+        args.learning_rate_pti = 1.0
+        args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -826,6 +847,8 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
+        args.learning_rate_pti = 1.0
+        args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
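In the D-Adaptation branches, every rate (now including the PTI and embedding ones) is pinned to 1.0 because these optimizers estimate the step size themselves and treat the per-group lr as a multiplier on that estimate. A hedged sketch of the idea, assuming the dadaptation package's DAdaptAdam class and using placeholder modules in place of the script's real parameter lists:

import dadaptation  # pip install dadaptation
import torch

# Placeholder modules standing in for the LoRA layers and the token embeddings.
unet_like = torch.nn.Linear(8, 8)
emb_like = torch.nn.Embedding(4, 8)

# With D-Adaptation, lr=1.0 means "use the automatically estimated step size as-is".
optimizer = dadaptation.DAdaptAdam(
    [
        {"params": unet_like.parameters(), "lr": 1.0},
        {"params": emb_like.parameters(), "lr": 1.0, "weight_decay": 0},
    ]
)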
@@ -949,7 +972,8 @@ def main():
             sample_frequency=pti_sample_frequency,
         )
 
-        embeddings.persist()
+        if not args.train_emb:
+            embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -974,13 +998,13 @@ def main():
 
     params_to_optimize = []
     group_labels = []
-    # if len(args.placeholder_tokens) != 0:
-    #     params_to_optimize.append({
-    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-    #         "lr": args.learning_rate_text,
-    #         "weight_decay": 0,
-    #     })
-    #     group_labels.append("emb")
+    if len(args.placeholder_tokens) != 0 and args.train_emb:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            "lr": args.learning_rate_emb,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (
