summaryrefslogtreecommitdiffstats
path: root/dreambooth_plus.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-19 12:19:23 +0200
committerVolpeon <git@volpeon.ink>2022-10-19 12:19:23 +0200
commitb4a00845721fbc95819ad888dfd7c24013bbf4d0 (patch)
treedf5888d0a52077d7fb1035939fb2b2e8547a0655 /dreambooth_plus.py
parentAdapted other scripts for new prompt processing (diff)
downloadtextual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.gz
textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.bz2
textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.zip
Updated Dreambooth training
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--dreambooth_plus.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 06ff45b..413abe3 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -125,7 +125,7 @@ def parse_args():
125 parser.add_argument( 125 parser.add_argument(
126 "--max_train_steps", 126 "--max_train_steps",
127 type=int, 127 type=int,
128 default=2400, 128 default=4700,
129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
130 ) 130 )
131 parser.add_argument( 131 parser.add_argument(
@@ -142,13 +142,13 @@ def parse_args():
142 parser.add_argument( 142 parser.add_argument(
143 "--learning_rate_unet", 143 "--learning_rate_unet",
144 type=float, 144 type=float,
145 default=5e-6, 145 default=2e-6,
146 help="Initial learning rate (after the potential warmup period) to use.", 146 help="Initial learning rate (after the potential warmup period) to use.",
147 ) 147 )
148 parser.add_argument( 148 parser.add_argument(
149 "--learning_rate_text", 149 "--learning_rate_text",
150 type=float, 150 type=float,
151 default=5e-6, 151 default=2e-6,
152 help="Initial learning rate (after the potential warmup period) to use.", 152 help="Initial learning rate (after the potential warmup period) to use.",
153 ) 153 )
154 parser.add_argument( 154 parser.add_argument(
@@ -578,6 +578,7 @@ def main():
578 578
579 if args.gradient_checkpointing: 579 if args.gradient_checkpointing:
580 unet.enable_gradient_checkpointing() 580 unet.enable_gradient_checkpointing()
581 text_encoder.gradient_checkpointing_enable()
581 582
582 # slice_size = unet.config.attention_head_dim // 2 583 # slice_size = unet.config.attention_head_dim // 2
583 # unet.set_attention_slice(slice_size) 584 # unet.set_attention_slice(slice_size)