Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 73
1 file changed, 46 insertions(+), 27 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 29e40b2..073e939 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -87,6 +87,12 @@ def parse_args():
         help="How many cycles to run automatically."
     )
     parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
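
For reference, a minimal sketch of what the new flag does in isolation; the standalone parser and the example value 0.9 below are illustrative, not taken from the repository. With the default of 1.0 the learning rates are left unchanged from cycle to cycle, while a value below 1.0 shrinks them after every cycle.

import argparse

# Hypothetical standalone parser mirroring the new argument added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cycle_decay",
    type=float,
    default=1.0,
    help="Learning rate decay per cycle."
)

args = parser.parse_args(["--cycle_decay", "0.9"])
# Each finished cycle multiplies the working learning rates by this factor,
# so 0.9 means the next cycle starts at 90% of the previous cycle's rates.
print(args.cycle_decay)  # 0.9
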
@@ -924,39 +930,15 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
-    params_to_optimize = []
     group_labels = []
     if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            "lr": args.learning_rate_emb,
-            "weight_decay": 0,
-        })
         group_labels.append("emb")
-    params_to_optimize += [
-        {
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_unet,
-        },
-        {
-            "params": (
-                param
-                for param in itertools.chain(
-                    text_encoder.text_model.encoder.parameters(),
-                    text_encoder.text_model.final_layer_norm.parameters(),
-                )
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_text,
-        },
-    ]
     group_labels += ["unet", "text"]
 
     training_iter = 0
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text
 
     lora_project = "lora"
 
@@ -973,6 +955,37 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
+        params_to_optimize = []
+
+        if len(args.placeholder_tokens) != 0:
+            params_to_optimize.append({
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+                "lr": learning_rate_emb,
+                "weight_decay": 0,
+            })
+            group_labels.append("emb")
+        params_to_optimize += [
+            {
+                "params": (
+                    param
+                    for param in unet.parameters()
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_unet,
+            },
+            {
+                "params": (
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_text,
+            },
+        ]
+
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
@@ -1002,6 +1015,12 @@ def main():
         )
 
         training_iter += 1
+        if args.learning_rate_emb is not None:
+            learning_rate_emb *= args.cycle_decay
+        if args.learning_rate_unet is not None:
+            learning_rate_unet *= args.cycle_decay
+        if args.learning_rate_text is not None:
+            learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
 
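Taken together, the commit moves the construction of params_to_optimize inside the per-cycle loop and multiplies the working learning rates by --cycle_decay once a cycle finishes, so each new optimizer is created with the decayed rates. Below is a minimal sketch of the resulting schedule, assuming illustrative base rates and cycle count; the values are hypothetical, not defaults from train_lora.py.

# Illustration of the per-cycle decay pattern introduced by this commit.
auto_cycles = 4            # illustrative stand-in for --auto_cycles
cycle_decay = 0.9          # illustrative stand-in for --cycle_decay
learning_rate_unet = 1e-4  # illustrative base rate
learning_rate_text = 5e-5  # illustrative base rate

for training_iter in range(auto_cycles):
    # In train_lora.py these rates feed params_to_optimize for this cycle's optimizer.
    print(
        f"cycle {training_iter + 1}: "
        f"unet lr = {learning_rate_unet:.2e}, text lr = {learning_rate_text:.2e}"
    )
    learning_rate_unet *= cycle_decay
    learning_rate_text *= cycle_decay

# Equivalently, the rate used by cycle n (counting from 0) is base_lr * cycle_decay ** n.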