summary | refs | log | tree | commit | diff | stats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- train_dreambooth.py 38
1 files changed, 27 insertions, 11 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a7911c..8f0c6ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -285,9 +285,10 @@ def parse_args():
285 default=0.9999 285 default=0.9999
286 ) 286 )
287 parser.add_argument( 287 parser.add_argument(
288 "--use_8bit_adam", 288 "--optimizer",
289 action="store_true", 289 type=str,
290 help="Whether or not to use 8-bit Adam from bitsandbytes." 290 default="lion",
291 help='Optimizer to use ["adam", "adam8bit", "lion"]'
291 ) 292 )
292 parser.add_argument( 293 parser.add_argument(
293 "--adam_beta1", 294 "--adam_beta1",
@@ -491,15 +492,34 @@ def main():
491 args.learning_rate = 1e-6 492 args.learning_rate = 1e-6
492 args.lr_scheduler = "exponential_growth" 493 args.lr_scheduler = "exponential_growth"
493 494
494 if args.use_8bit_adam: 495 if args.optimizer == 'adam8bit':
495 try: 496 try:
496 import bitsandbytes as bnb 497 import bitsandbytes as bnb
497 except ImportError: 498 except ImportError:
498 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") 499 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
499 500
500 optimizer_class = bnb.optim.AdamW8bit 501 create_optimizer = partial(
502 bnb.optim.AdamW8bit,
503 betas=(args.adam_beta1, args.adam_beta2),
504 weight_decay=args.adam_weight_decay,
505 eps=args.adam_epsilon,
506 amsgrad=args.adam_amsgrad,
507 )
508 elif args.optimizer == 'adam':
509 create_optimizer = partial(
510 torch.optim.AdamW,
511 betas=(args.adam_beta1, args.adam_beta2),
512 weight_decay=args.adam_weight_decay,
513 eps=args.adam_epsilon,
514 amsgrad=args.adam_amsgrad,
515 )
501 else: 516 else:
502 optimizer_class = torch.optim.AdamW 517 try:
518 from lion_pytorch import Lion
519 except ImportError:
520 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
521
522 create_optimizer = partial(Lion, use_triton=True)
503 523
504 trainer = partial( 524 trainer = partial(
505 train, 525 train,
@@ -540,17 +560,13 @@ def main():
540 ) 560 )
541 datamodule.setup() 561 datamodule.setup()
542 562
543 optimizer = optimizer_class( 563 optimizer = create_optimizer(
544 itertools.chain( 564 itertools.chain(
545 unet.parameters(), 565 unet.parameters(),
546 text_encoder.text_model.encoder.parameters(), 566 text_encoder.text_model.encoder.parameters(),
547 text_encoder.text_model.final_layer_norm.parameters(), 567 text_encoder.text_model.final_layer_norm.parameters(),
548 ), 568 ),
549 lr=args.learning_rate, 569 lr=args.learning_rate,
550 betas=(args.adam_beta1, args.adam_beta2),
551 weight_decay=args.adam_weight_decay,
552 eps=args.adam_epsilon,
553 amsgrad=args.adam_amsgrad,
554 ) 570 )
555 571
556 lr_scheduler = get_scheduler( 572 lr_scheduler = get_scheduler(