diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 19 |
1 file changed, 18 insertions, 1 deletion
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 929310b..90ca467 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -839,7 +839,10 @@ def main(): | |||
839 | 839 | ||
840 | create_optimizer = partial( | 840 | create_optimizer = partial( |
841 | prodigyopt.Prodigy, | 841 | prodigyopt.Prodigy, |
842 | betas=(args.adam_beta1, args.adam_beta2), | ||
842 | weight_decay=args.adam_weight_decay, | 843 | weight_decay=args.adam_weight_decay, |
844 | eps=args.adam_epsilon, | ||
845 | d0=args.dadaptation_d0, | ||
843 | ) | 846 | ) |
844 | 847 | ||
845 | args.learning_rate_unet = 1.0 | 848 | args.learning_rate_unet = 1.0 |
@@ -965,9 +968,23 @@ def main(): | |||
965 | }, | 968 | }, |
966 | { | 969 | { |
967 | "params": ( | 970 | "params": ( |
968 | param for param in text_encoder.parameters() if param.requires_grad | 971 | param |
972 | for param in itertools.chain( | ||
973 | text_encoder.text_model.encoder.parameters(), | ||
974 | text_encoder.text_model.final_layer_norm.parameters(), | ||
975 | ) | ||
976 | if param.requires_grad | ||
977 | ), | ||
978 | "lr": learning_rate_text, | ||
979 | }, | ||
980 | { | ||
981 | "params": ( | ||
982 | param | ||
983 | for param in text_encoder.text_model.embeddings.token_embedding.parameters() | ||
984 | if param.requires_grad | ||
969 | ), | 985 | ), |
970 | "lr": learning_rate_text, | 986 | "lr": learning_rate_text, |
987 | "weight_decay": 0, | ||
971 | }, | 988 | }, |
972 | ] | 989 | ] |
973 | group_labels = ["unet", "text"] | 990 | group_labels = ["unet", "text"] |