-rw-r--r--  environment.yaml    |  1 +
-rw-r--r--  train_dreambooth.py | 30 +++++++++++++++++++++++++++---
-rw-r--r--  train_lora.py       | 30 +++++++++++++++++++++++++++---
-rw-r--r--  train_ti.py         | 30 +++++++++++++++++++++++++++---
4 files changed, 82 insertions(+), 9 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 8868532..1de76bd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
   - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
   - accelerate==0.17.1
   - bitsandbytes==0.37.2
+  - lion-pytorch==0.0.7
   - peft==0.2.0
   - python-slugify>=6.1.2
   - safetensors==0.3.0
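The conda environment now pins lion-pytorch 0.0.7. For pip-managed setups, the import guard added in the scripts below names the equivalent install; pinning the same version here is an assumption made for parity with the conda file:

    pip install lion-pytorch==0.0.7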
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48b7926..be7d6fe 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -317,13 +317,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -450,6 +450,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -536,6 +548,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
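All three scripts wire Lion through the same `create_optimizer` factory pattern used for the existing optimizers: optimizer-specific hyperparameters are bound with `functools.partial`, and the trainer supplies the parameters and learning rate later. A minimal standalone sketch of that pattern (the model, learning rate, and weight decay are placeholders rather than values from the scripts; `use_triton` is set to False so the sketch runs without a CUDA device):

    import torch
    from functools import partial
    from lion_pytorch import Lion

    # Bind Lion-specific hyperparameters now; params and lr are supplied later,
    # mirroring how the diff builds create_optimizer.
    create_optimizer = partial(
        Lion,
        betas=(0.95, 0.98),   # the Lion defaults the diff falls back to
        weight_decay=1e-2,    # placeholder; the scripts pass args.adam_weight_decay
        use_triton=False,     # the diff passes True, which needs triton + CUDA
    )

    model = torch.nn.Linear(4, 4)  # placeholder model
    optimizer = create_optimizer(model.parameters(), lr=1e-4)

    # One optimization step to show the factory result is a normal optimizer.
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()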
diff --git a/train_lora.py b/train_lora.py
index cf73645..a0cd174 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -318,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -329,13 +329,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -468,6 +468,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -568,6 +580,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
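Note that all three scripts pass `use_triton=True` unconditionally. lion-pytorch's Triton path assumes the `triton` package and a CUDA device are available at runtime, so a CPU-only run would need the flag disabled. A defensive variant (a sketch, not what the diff does) could gate it on GPU availability:

    import torch

    use_triton = torch.cuda.is_available()  # fall back to the pure-PyTorch path on CPU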
diff --git a/train_ti.py b/train_ti.py
index 651dfbe..c242625 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -341,13 +341,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -566,6 +566,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -666,6 +678,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
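With the parser changes in place, Lion is selected per run from the command line, and the beta defaults resolve to (0.95, 0.98) unless overridden. A hypothetical invocation (any other flags the script requires are omitted here, and the output path is illustrative):

    python train_dreambooth.py --optimizer lion --output_dir output/lion-run

    # Override the Lion beta defaults explicitly:
    python train_dreambooth.py --optimizer lion --adam_beta1 0.9 --adam_beta2 0.99 --output_dir output/lion-run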
