diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-19 12:19:23 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-19 12:19:23 +0200 |
| commit | b4a00845721fbc95819ad888dfd7c24013bbf4d0 (patch) | |
| tree | df5888d0a52077d7fb1035939fb2b2e8547a0655 /dreambooth_plus.py | |
| parent | Adapted other scripts for new prompt processing (diff) | |
| download | textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.gz textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.bz2 textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.zip | |
Updated Dreambooth training
Diffstat (limited to 'dreambooth_plus.py')
| -rw-r--r-- | dreambooth_plus.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 06ff45b..413abe3 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
| @@ -125,7 +125,7 @@ def parse_args(): | |||
| 125 | parser.add_argument( | 125 | parser.add_argument( |
| 126 | "--max_train_steps", | 126 | "--max_train_steps", |
| 127 | type=int, | 127 | type=int, |
| 128 | default=2400, | 128 | default=4700, |
| 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 130 | ) | 130 | ) |
| 131 | parser.add_argument( | 131 | parser.add_argument( |
| @@ -142,13 +142,13 @@ def parse_args(): | |||
| 142 | parser.add_argument( | 142 | parser.add_argument( |
| 143 | "--learning_rate_unet", | 143 | "--learning_rate_unet", |
| 144 | type=float, | 144 | type=float, |
| 145 | default=5e-6, | 145 | default=2e-6, |
| 146 | help="Initial learning rate (after the potential warmup period) to use.", | 146 | help="Initial learning rate (after the potential warmup period) to use.", |
| 147 | ) | 147 | ) |
| 148 | parser.add_argument( | 148 | parser.add_argument( |
| 149 | "--learning_rate_text", | 149 | "--learning_rate_text", |
| 150 | type=float, | 150 | type=float, |
| 151 | default=5e-6, | 151 | default=2e-6, |
| 152 | help="Initial learning rate (after the potential warmup period) to use.", | 152 | help="Initial learning rate (after the potential warmup period) to use.", |
| 153 | ) | 153 | ) |
| 154 | parser.add_argument( | 154 | parser.add_argument( |
| @@ -578,6 +578,7 @@ def main(): | |||
| 578 | 578 | ||
| 579 | if args.gradient_checkpointing: | 579 | if args.gradient_checkpointing: |
| 580 | unet.enable_gradient_checkpointing() | 580 | unet.enable_gradient_checkpointing() |
| 581 | text_encoder.gradient_checkpointing_enable() | ||
| 581 | 582 | ||
| 582 | # slice_size = unet.config.attention_head_dim // 2 | 583 | # slice_size = unet.config.attention_head_dim // 2 |
| 583 | # unet.set_attention_slice(slice_size) | 584 | # unet.set_attention_slice(slice_size) |
