-rw-r--r--  environment.yaml     |  1
-rw-r--r--  train_dreambooth.py  | 30
-rw-r--r--  train_lora.py        | 30
-rw-r--r--  train_ti.py          | 30
4 files changed, 82 insertions, 9 deletions
diff --git a/environment.yaml b/environment.yaml
index 8868532..1de76bd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
     - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
     - accelerate==0.17.1
     - bitsandbytes==0.37.2
+    - lion-pytorch==0.0.7
     - peft==0.2.0
     - python-slugify>=6.1.2
     - safetensors==0.3.0
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48b7926..be7d6fe 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -317,13 +317,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -450,6 +450,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -536,6 +548,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
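
For reference, the new Lion branch in isolation: a minimal standalone sketch of what the hunks above wire up, assuming lion-pytorch==0.0.7 from the updated environment.yaml is installed. The model, learning rate, and weight-decay values are placeholders rather than the scripts' actual arguments, and use_triton is left off so the snippet also runs without a Triton-capable GPU (the diff passes use_triton=True).

from functools import partial

import torch
import lion_pytorch

# Per-optimizer beta defaults, mirroring the parse_args() fallback above:
# the Adam variants keep (0.9, 0.999), Lion gets (0.95, 0.98).
optimizer_name = "lion"   # placeholder for args.optimizer
adam_beta1 = None         # placeholder for args.adam_beta1
adam_beta2 = None         # placeholder for args.adam_beta2

if adam_beta1 is None:
    adam_beta1 = 0.95 if optimizer_name == "lion" else 0.9
if adam_beta2 is None:
    adam_beta2 = 0.98 if optimizer_name == "lion" else 0.999

# As in the diff, the constructor is bound with partial() so the caller
# only supplies the parameter groups and learning rate later.
create_optimizer = partial(
    lion_pytorch.Lion,
    betas=(adam_beta1, adam_beta2),
    weight_decay=1e-2,    # placeholder for args.adam_weight_decay
    use_triton=False,     # the scripts pass True; that path needs Triton on CUDA
)

model = torch.nn.Linear(4, 4)   # stand-in for the modules being trained
optimizer = create_optimizer(model.parameters(), lr=1e-4)
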
diff --git a/train_lora.py b/train_lora.py
index cf73645..a0cd174 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -318,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -329,13 +329,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -468,6 +468,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -568,6 +580,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
diff --git a/train_ti.py b/train_ti.py
index 651dfbe..c242625 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -341,13 +341,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -566,6 +566,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -666,6 +678,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
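
The scripts hand create_optimizer to their existing setup code, so nothing downstream changes. As a sanity check, here is a plain training-step sketch (hypothetical toy model, data, and loss, not taken from the scripts) showing that the Lion instance is a drop-in torch optimizer with the usual zero_grad/backward/step cycle.

import torch
import torch.nn.functional as F
import lion_pytorch

model = torch.nn.Linear(8, 1)
optimizer = lion_pytorch.Lion(
    model.parameters(), lr=1e-4, betas=(0.95, 0.98), weight_decay=1e-2
)

x = torch.randn(16, 8)        # toy batch
target = torch.randn(16, 1)   # toy targets

for _ in range(3):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()          # same loop as with Adam; only the update rule differs
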