Diffstat (limited to 'train_dreambooth.py')
 train_dreambooth.py | 38 +++++++++++++++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 11 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a7911c..8f0c6ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -285,9 +285,10 @@ def parse_args():
         default=0.9999
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
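Note: as written, any value other than "adam" or "adam8bit" falls through to the Lion branch below, so a typo such as `--optimizer adamw` would silently train with Lion. A stricter variant (not part of this patch) could use argparse's standard `choices` parameter to reject unknown names at parse time; a minimal sketch:

```python
# Sketch only: `choices` makes argparse fail fast on unrecognized optimizer names.
parser.add_argument(
    "--optimizer",
    type=str,
    default="lion",
    choices=["adam", "adam8bit", "lion"],
    help='Optimizer to use ["adam", "adam8bit", "lion"]',
)
```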
@@ -491,15 +492,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"

-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)

     trainer = partial(
         train,
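This hunk swaps a bare `optimizer_class` for a `create_optimizer` factory built with `functools.partial`: optimizer-specific keyword arguments (Adam's betas, weight decay, epsilon, amsgrad, or Lion's `use_triton`) are bound up front, while the arguments shared by all branches (the parameter iterable and the learning rate) are supplied later at the single call site. In the Lion branch only `use_triton=True` is bound, so Lion keeps its library defaults for betas and weight decay; the Triton code path additionally assumes the `triton` package is installed. A minimal, self-contained sketch of the same pattern (module and hyperparameter values are placeholders, not taken from the patch):

```python
from functools import partial

import torch

# Bind optimizer-specific kwargs now; params and lr are filled in later.
create_optimizer = partial(
    torch.optim.AdamW,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
    amsgrad=False,
)

model = torch.nn.Linear(4, 4)  # placeholder for the real UNet / text encoder
optimizer = create_optimizer(model.parameters(), lr=1e-6)
```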
@@ -540,17 +560,13 @@ def main():
     )
     datamodule.setup()

-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
         ),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )

     lr_scheduler = get_scheduler(
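With the hyperparameters moved into the factory, the call site shrinks to just the parameters and learning rate. The `itertools.chain` here works because `torch.optim` optimizers accept any iterable of parameters and consume it once at construction time; a toy sketch with stand-in modules (names and sizes are illustrative, not from the patch):

```python
import itertools

import torch

unet = torch.nn.Linear(4, 4)          # stand-in for the UNet
text_encoder = torch.nn.LayerNorm(4)  # stand-in for the text encoder submodules

# chain() lazily concatenates both parameter iterators into one flat iterable,
# which the optimizer materializes into its parameter groups at construction.
optimizer = torch.optim.AdamW(
    itertools.chain(unet.parameters(), text_encoder.parameters()),
    lr=1e-6,
)
```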
