From b4a00845721fbc95819ad888dfd7c24013bbf4d0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 19 Oct 2022 12:19:23 +0200 Subject: Updated Dreambooth training --- dreambooth_plus.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'dreambooth_plus.py') diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 06ff45b..413abe3 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -125,7 +125,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=2400, + default=4700, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -142,13 +142,13 @@ def parse_args(): parser.add_argument( "--learning_rate_unet", type=float, - default=5e-6, + default=2e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--learning_rate_text", type=float, - default=5e-6, + default=2e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -578,6 +578,7 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() # slice_size = unet.config.attention_head_dim // 2 # unet.set_attention_slice(slice_size) -- cgit v1.2.3-54-g00ecf