| author | Volpeon <git@volpeon.ink> | 2023-04-13 07:14:24 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-13 07:14:24 +0200 |
| commit | a0b63ee7f4a8c793c0d200c86ef07677aa4cbf2e (patch) | |
| tree | 6a695b2b5a73cebc35ff9e581c70f1a0e75b62e8 /train_lora.py | |
| parent | Experimental convnext discriminator support (diff) | |
| download | textual-inversion-diff-a0b63ee7f4a8c793c0d200c86ef07677aa4cbf2e.tar.gz, textual-inversion-diff-a0b63ee7f4a8c793c0d200c86ef07677aa4cbf2e.tar.bz2, textual-inversion-diff-a0b63ee7f4a8c793c0d200c86ef07677aa4cbf2e.zip | |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  73
1 file changed, 46 insertions, 27 deletions
diff --git a/train_lora.py b/train_lora.py
index 29e40b2..073e939 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -87,6 +87,12 @@ def parse_args():
         help="How many cycles to run automatically."
     )
     parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -924,39 +930,15 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)

-    params_to_optimize = []
     group_labels = []
     if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            "lr": args.learning_rate_emb,
-            "weight_decay": 0,
-        })
         group_labels.append("emb")
-    params_to_optimize += [
-        {
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_unet,
-        },
-        {
-            "params": (
-                param
-                for param in itertools.chain(
-                    text_encoder.text_model.encoder.parameters(),
-                    text_encoder.text_model.final_layer_norm.parameters(),
-                )
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_text,
-        },
-    ]
     group_labels += ["unet", "text"]

     training_iter = 0
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text

     lora_project = "lora"

@@ -973,6 +955,37 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")

+        params_to_optimize = []
+
+        if len(args.placeholder_tokens) != 0:
+            params_to_optimize.append({
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+                "lr": learning_rate_emb,
+                "weight_decay": 0,
+            })
+            group_labels.append("emb")
+        params_to_optimize += [
+            {
+                "params": (
+                    param
+                    for param in unet.parameters()
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_unet,
+            },
+            {
+                "params": (
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_text,
+            },
+        ]
+
         lora_optimizer = create_optimizer(params_to_optimize)

         lora_lr_scheduler = create_lr_scheduler(
@@ -1002,6 +1015,12 @@ def main():
         )

         training_iter += 1
+        if args.learning_rate_emb is not None:
+            learning_rate_emb *= args.cycle_decay
+        if args.learning_rate_unet is not None:
+            learning_rate_unet *= args.cycle_decay
+        if args.learning_rate_text is not None:
+            learning_rate_text *= args.cycle_decay

     accelerator.end_training()

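In practice this change means the optimizer parameter groups are rebuilt at the start of every LoRA cycle, and after each cycle the embedding, UNet, and text-encoder learning rates are multiplied by `--cycle_decay` before the next optimizer is created. A minimal standalone sketch of that schedule is shown below; the function and variable names are illustrative only and do not appear in train_lora.py.

```python
# Minimal sketch of the per-cycle learning rate decay added in this commit:
# each cycle's optimizer is built with the current rate, which is then
# multiplied by cycle_decay before the next cycle. Names are hypothetical.
def cycle_learning_rates(base_lr: float, cycle_decay: float, cycles: int):
    lr = base_lr
    for cycle in range(cycles):
        yield cycle, lr      # rate used when building this cycle's optimizer
        lr *= cycle_decay    # decayed for the following cycle

# Example: starting at 1e-4 with --cycle_decay 0.9 over 4 cycles.
for cycle, lr in cycle_learning_rates(1e-4, 0.9, 4):
    print(f"cycle {cycle + 1}: lr = {lr:.2e}")
# cycle 1: lr = 1.00e-04
# cycle 2: lr = 9.00e-05
# cycle 3: lr = 8.10e-05
# cycle 4: lr = 7.29e-05
```

With the default of 1.0 the rates stay constant across cycles, matching the previous behaviour.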