Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  38
1 file changed, 27 insertions(+), 11 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a7911c..8f0c6ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -285,9 +285,10 @@ def parse_args():
         default=0.9999
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -491,15 +492,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     trainer = partial(
         train,
@@ -540,17 +560,13 @@ def main():
         )
         datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
         ),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
    )
 
     lr_scheduler = get_scheduler(
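
A note on the pattern this change introduces: optimizer-specific keyword arguments are bound with functools.partial when the optimizer is selected, so the shared call site only supplies the parameter iterator and the learning rate. Below is a minimal standalone sketch of that pattern, not part of the training script; the toy Linear module and the hyperparameter values are illustrative stand-ins.

from functools import partial

import torch

# Bind optimizer-specific hyperparameters up front, as the diff does for
# AdamW, AdamW8bit, and Lion; the call site stays optimizer-agnostic.
create_optimizer = partial(
    torch.optim.AdamW,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)

model = torch.nn.Linear(4, 4)  # stand-in for the unet/text_encoder parameters
optimizer = create_optimizer(model.parameters(), lr=1e-6)
print(type(optimizer).__name__)  # AdamW

With the new flag, the optimizer is chosen on the command line, e.g. train_dreambooth.py --optimizer adam8bit (remaining arguments omitted). Note that the default is now "lion" rather than AdamW, so a plain invocation requires lion_pytorch, and the Lion branch hard-codes use_triton=True.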