diff options
author | Volpeon <git@volpeon.ink> | 2022-10-19 12:19:23 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-19 12:19:23 +0200 |
commit | b4a00845721fbc95819ad888dfd7c24013bbf4d0 (patch) | |
tree | df5888d0a52077d7fb1035939fb2b2e8547a0655 /dreambooth_plus.py | |
parent | Adapted other scripts for new prompt processing (diff) | |
download | textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.gz textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.bz2 textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.zip |
Updated Dreambooth training
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 06ff45b..413abe3 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -125,7 +125,7 @@ def parse_args(): | |||
125 | parser.add_argument( | 125 | parser.add_argument( |
126 | "--max_train_steps", | 126 | "--max_train_steps", |
127 | type=int, | 127 | type=int, |
128 | default=2400, | 128 | default=4700, |
129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
130 | ) | 130 | ) |
131 | parser.add_argument( | 131 | parser.add_argument( |
@@ -142,13 +142,13 @@ def parse_args(): | |||
142 | parser.add_argument( | 142 | parser.add_argument( |
143 | "--learning_rate_unet", | 143 | "--learning_rate_unet", |
144 | type=float, | 144 | type=float, |
145 | default=5e-6, | 145 | default=2e-6, |
146 | help="Initial learning rate (after the potential warmup period) to use.", | 146 | help="Initial learning rate (after the potential warmup period) to use.", |
147 | ) | 147 | ) |
148 | parser.add_argument( | 148 | parser.add_argument( |
149 | "--learning_rate_text", | 149 | "--learning_rate_text", |
150 | type=float, | 150 | type=float, |
151 | default=5e-6, | 151 | default=2e-6, |
152 | help="Initial learning rate (after the potential warmup period) to use.", | 152 | help="Initial learning rate (after the potential warmup period) to use.", |
153 | ) | 153 | ) |
154 | parser.add_argument( | 154 | parser.add_argument( |
@@ -578,6 +578,7 @@ def main(): | |||
578 | 578 | ||
579 | if args.gradient_checkpointing: | 579 | if args.gradient_checkpointing: |
580 | unet.enable_gradient_checkpointing() | 580 | unet.enable_gradient_checkpointing() |
581 | text_encoder.gradient_checkpointing_enable() | ||
581 | 582 | ||
582 | # slice_size = unet.config.attention_head_dim // 2 | 583 | # slice_size = unet.config.attention_head_dim // 2 |
583 | # unet.set_attention_slice(slice_size) | 584 | # unet.set_attention_slice(slice_size) |