From 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Jun 2023 08:40:05 +0200 Subject: Update --- train_dreambooth.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 929310b..90ca467 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -839,7 +839,10 @@ def main(): create_optimizer = partial( prodigyopt.Prodigy, + betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + d0=args.dadaptation_d0, ) args.learning_rate_unet = 1.0 @@ -965,9 +968,23 @@ def main(): }, { "params": ( - param for param in text_encoder.parameters() if param.requires_grad + param + for param in itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + ) + if param.requires_grad + ), + "lr": learning_rate_text, + }, + { + "params": ( + param + for param in text_encoder.text_model.embeddings.token_embedding.parameters() + if param.requires_grad ), "lr": learning_rate_text, + "weight_decay": 0, }, ] group_labels = ["unet", "text"] -- cgit v1.2.3-54-g00ecf