summary | refs | log | tree | commit | diff | stats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- train_dreambooth.py 30
1 files changed, 27 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48b7926..be7d6fe 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
306 "--optimizer", 306 "--optimizer",
307 type=str, 307 type=str,
308 default="dadan", 308 default="dadan",
309 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' 309 help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
310 ) 310 )
311 parser.add_argument( 311 parser.add_argument(
312 "--dadaptation_d0", 312 "--dadaptation_d0",
@@ -317,13 +317,13 @@ def parse_args():
317 parser.add_argument( 317 parser.add_argument(
318 "--adam_beta1", 318 "--adam_beta1",
319 type=float, 319 type=float,
320 default=0.9, 320 default=None,
321 help="The beta1 parameter for the Adam optimizer." 321 help="The beta1 parameter for the Adam optimizer."
322 ) 322 )
323 parser.add_argument( 323 parser.add_argument(
324 "--adam_beta2", 324 "--adam_beta2",
325 type=float, 325 type=float,
326 default=0.999, 326 default=None,
327 help="The beta2 parameter for the Adam optimizer." 327 help="The beta2 parameter for the Adam optimizer."
328 ) 328 )
329 parser.add_argument( 329 parser.add_argument(
@@ -450,6 +450,18 @@ def parse_args():
450 if args.output_dir is None: 450 if args.output_dir is None:
451 raise ValueError("You must specify --output_dir") 451 raise ValueError("You must specify --output_dir")
452 452
453 if args.adam_beta1 is None:
454 if args.optimizer in ('adam', 'adam8bit'):
455 args.adam_beta1 = 0.9
456 elif args.optimizer == 'lion':
457 args.adam_beta1 = 0.95
458
459 if args.adam_beta2 is None:
460 if args.optimizer in ('adam', 'adam8bit'):
461 args.adam_beta2 = 0.999
462 elif args.optimizer == 'lion':
463 args.adam_beta2 = 0.98
464
453 return args 465 return args
454 466
455 467
@@ -536,6 +548,18 @@ def main():
536 eps=args.adam_epsilon, 548 eps=args.adam_epsilon,
537 amsgrad=args.adam_amsgrad, 549 amsgrad=args.adam_amsgrad,
538 ) 550 )
551 elif args.optimizer == 'lion':
552 try:
553 import lion_pytorch
554 except ImportError:
555 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
556
557 create_optimizer = partial(
558 lion_pytorch.Lion,
559 betas=(args.adam_beta1, args.adam_beta2),
560 weight_decay=args.adam_weight_decay,
561 use_triton=True,
562 )
539 elif args.optimizer == 'adafactor': 563 elif args.optimizer == 'adafactor':
540 create_optimizer = partial( 564 create_optimizer = partial(
541 transformers.optimization.Adafactor, 565 transformers.optimization.Adafactor,