summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index be7d6fe..4c36ae4 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -130,7 +130,8 @@ def parse_args():
130 "--vector_shuffle", 130 "--vector_shuffle",
131 type=str, 131 type=str,
132 default="auto", 132 default="auto",
133 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', 133 choices=["all", "trailing", "leading", "between", "auto", "off"],
134 help='Vector shuffling algorithm.',
134 ) 135 )
135 parser.add_argument( 136 parser.add_argument(
136 "--guidance_scale", 137 "--guidance_scale",
@@ -229,10 +230,9 @@ def parse_args():
229 "--lr_scheduler", 230 "--lr_scheduler",
230 type=str, 231 type=str,
231 default="one_cycle", 232 default="one_cycle",
232 help=( 233 choices=["linear", "cosine", "cosine_with_restarts", "polynomial",
233 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 234 "constant", "constant_with_warmup", "one_cycle"],
234 ' "constant", "constant_with_warmup", "one_cycle"]' 235 help='The scheduler type to use.',
235 ),
236 ) 236 )
237 parser.add_argument( 237 parser.add_argument(
238 "--lr_warmup_epochs", 238 "--lr_warmup_epochs",
@@ -256,7 +256,7 @@ def parse_args():
256 "--lr_warmup_func", 256 "--lr_warmup_func",
257 type=str, 257 type=str,
258 default="cos", 258 default="cos",
259 help='Choose between ["linear", "cos"]' 259 choices=["linear", "cos"],
260 ) 260 )
261 parser.add_argument( 261 parser.add_argument(
262 "--lr_warmup_exp", 262 "--lr_warmup_exp",
@@ -268,7 +268,7 @@ def parse_args():
268 "--lr_annealing_func", 268 "--lr_annealing_func",
269 type=str, 269 type=str,
270 default="cos", 270 default="cos",
271 help='Choose between ["linear", "half_cos", "cos"]' 271 choices=["linear", "half_cos", "cos"],
272 ) 272 )
273 parser.add_argument( 273 parser.add_argument(
274 "--lr_annealing_exp", 274 "--lr_annealing_exp",
@@ -306,7 +306,8 @@ def parse_args():
306 "--optimizer", 306 "--optimizer",
307 type=str, 307 type=str,
308 default="dadan", 308 default="dadan",
309 help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]' 309 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"],
310 help='Optimizer to use'
310 ) 311 )
311 parser.add_argument( 312 parser.add_argument(
312 "--dadaptation_d0", 313 "--dadaptation_d0",