path: root/train_lora.py
author     Volpeon <git@volpeon.ink>    2023-02-17 21:06:11 +0100
committer  Volpeon <git@volpeon.ink>    2023-02-17 21:06:11 +0100
commit     f894dfecfaa3ec17903b2ac37ac4f071408613db (patch)
tree       02bf8439315c832528651186285f8b1fbd649f32 /train_lora.py
parent     Inference script: Better scheduler config (diff)
download   textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.gz
           textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.bz2
           textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.zip
Added Lion optimizer
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  38
1 file changed, 27 insertions, 11 deletions
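
The commit replaces the boolean `--use_8bit_adam` flag with a string-valued `--optimizer` option ("adam", "adam8bit", "lion") and swaps the bare `optimizer_class` for a `create_optimizer` factory built with `functools.partial`, so the optimizer-specific hyperparameters are bound once and the call site only supplies the parameters and learning rate. A minimal sketch of that pattern, with an illustrative stand-in for the script's parsed arguments and a toy module in place of `lora_layers`:

from functools import partial

import torch

# Illustrative stand-in for the script's argparse namespace (values are placeholders).
class Args:
    adam_beta1, adam_beta2 = 0.9, 0.999
    adam_weight_decay, adam_epsilon, adam_amsgrad = 1e-2, 1e-8, False
    learning_rate = 1e-4

args = Args()

# Bind the optimizer-specific hyperparameters up front...
create_optimizer = partial(
    torch.optim.AdamW,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
    amsgrad=args.adam_amsgrad,
)

# ...then pass only the parameters and learning rate at the call site,
# which is how the script now treats all three optimizer choices.
model = torch.nn.Linear(4, 4)  # toy stand-in for lora_layers
optimizer = create_optimizer(model.parameters(), lr=args.learning_rate)
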
diff --git a/train_lora.py b/train_lora.py
index 330bcd6..368c29b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,9 +245,10 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -466,15 +467,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"

-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)

     trainer = partial(
         train,
@@ -516,13 +536,9 @@
     )
     datamodule.setup()

-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         lora_layers.parameters(),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )

     lr_scheduler = get_scheduler(
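
For the default "lion" branch, the factory wraps `lion_pytorch.Lion` with `use_triton=True`, which additionally requires a working Triton install. A minimal, illustrative sketch with the Triton kernel disabled so it also runs on CPU-only setups (the tiny module and learning rate are placeholders):

from functools import partial

import torch
from lion_pytorch import Lion  # pip install lion-pytorch

model = torch.nn.Linear(4, 4)  # toy stand-in for lora_layers

# Mirrors the script's `create_optimizer = partial(Lion, use_triton=True)`,
# but with the fused Triton kernel turned off for portability.
create_optimizer = partial(Lion, use_triton=False)
optimizer = create_optimizer(model.parameters(), lr=1e-4)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
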