Diffstat (limited to 'train_dreambooth.py')
 -rw-r--r--  train_dreambooth.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 929310b..90ca467 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -839,7 +839,10 @@ def main():
 
         create_optimizer = partial(
             prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate_unet = 1.0
@@ -965,9 +968,23 @@ def main():
             },
             {
                 "params": (
-                    param for param in text_encoder.parameters() if param.requires_grad
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_text,
+            },
+            {
+                "params": (
+                    param
+                    for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                    if param.requires_grad
                 ),
                 "lr": learning_rate_text,
+                "weight_decay": 0,
             },
         ]
         group_labels = ["unet", "text"]
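
For reference, here is a minimal sketch of how the pieces touched by this diff fit together. It assumes prodigyopt is installed and that unet and text_encoder are the models this script trains; the build_optimizer wrapper, the literal hyperparameter values, and the unet parameter group are illustrative placeholders for the script's argparse options (args.adam_beta1, args.adam_beta2, args.adam_weight_decay, args.adam_epsilon, args.dadaptation_d0) and its existing group setup, not code from the repository. The second and third groups mirror the new split: the text encoder body keeps the default weight decay, while the token embeddings get weight_decay=0, presumably so the learned embedding vectors are not pulled toward zero by weight decay.

import itertools
from functools import partial

import prodigyopt


def build_optimizer(unet, text_encoder, lr_unet=1.0, lr_text=1.0):
    # Hypothetical wrapper; the script builds these pieces inline inside main().
    create_optimizer = partial(
        prodigyopt.Prodigy,
        betas=(0.9, 0.999),   # stands in for (args.adam_beta1, args.adam_beta2)
        weight_decay=1e-2,    # args.adam_weight_decay
        eps=1e-8,             # args.adam_epsilon
        d0=1e-6,              # args.dadaptation_d0 (Prodigy's initial d estimate)
    )

    params = [
        # Placeholder for the script's existing unet group.
        {"params": unet.parameters(), "lr": lr_unet},
        # Text encoder body: transformer blocks plus final layer norm, default weight decay.
        {
            "params": (
                p
                for p in itertools.chain(
                    text_encoder.text_model.encoder.parameters(),
                    text_encoder.text_model.final_layer_norm.parameters(),
                )
                if p.requires_grad
            ),
            "lr": lr_text,
        },
        # Token embeddings: same learning rate, but weight decay disabled.
        {
            "params": (
                p
                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
                if p.requires_grad
            ),
            "lr": lr_text,
            "weight_decay": 0,
        },
    ]
    return create_optimizer(params)

Because Prodigy adapts its own effective step size, the base learning rates are typically left at 1.0, which matches the diff setting args.learning_rate_unet = 1.0.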