path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-02-17 21:06:11 +0100
committer  Volpeon <git@volpeon.ink>  2023-02-17 21:06:11 +0100
commit     f894dfecfaa3ec17903b2ac37ac4f071408613db (patch)
tree       02bf8439315c832528651186285f8b1fbd649f32 /train_ti.py
parent     Inference script: Better scheduler config (diff)
download   textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.gz
           textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.bz2
           textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.zip
Added Lion optimizer
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  38
1 file changed, 27 insertions, 11 deletions
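
The diff below replaces the boolean --use_8bit_adam flag with a string-valued --optimizer option that selects between AdamW, 8-bit AdamW from bitsandbytes, and Lion from lion_pytorch. As a minimal standalone sketch (not the committed file), the new option can be exercised like this; the choices= argument is an assumption added here for stricter validation, since the commit only lists the accepted values in the help text:

# Minimal standalone sketch of the new --optimizer option (illustration only).
# Assumption: choices= is added here for validation; the commit documents the
# accepted values only in the help string.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    default="lion",
    choices=["adam", "adam8bit", "lion"],
    help='Optimizer to use ["adam", "adam8bit", "lion"]',
)

args = parser.parse_args(["--optimizer", "adam8bit"])
print(args.optimizer)  # adam8bit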
diff --git a/train_ti.py b/train_ti.py
index 3aa1027..507d710 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -290,9 +290,10 @@ def parse_args():
         default=0.9999
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -564,15 +565,34 @@ def main():
         args.learning_rate = 1e-5
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     checkpoint_output_dir = output_dir/"checkpoints"
 
@@ -658,13 +678,9 @@ def main():
     )
     datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )
 
     lr_scheduler = get_scheduler(
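
Taken together, the hunks above replace the single optimizer_class with a create_optimizer factory: functools.partial binds the optimizer-specific hyperparameters up front, and the training loop later supplies only the parameters and the learning rate. A self-contained sketch of that pattern, assuming a toy nn.Embedding in place of the text encoder's temp_token_embedding and placeholder hyperparameter values:

# Self-contained sketch of the optimizer-factory pattern from this commit.
# Assumptions: a toy nn.Embedding stands in for temp_token_embedding and the
# hyperparameter values below are placeholders, not the repository defaults.
from functools import partial

import torch


def make_optimizer_factory(name: str):
    if name == "adam8bit":
        import bitsandbytes as bnb  # optional dependency
        return partial(
            bnb.optim.AdamW8bit,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-8,
        )
    elif name == "adam":
        return partial(
            torch.optim.AdamW,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-8,
        )
    else:
        from lion_pytorch import Lion  # optional dependency
        return partial(Lion, use_triton=True)


# Usage: the factory only needs the trainable parameters and the learning rate.
embedding = torch.nn.Embedding(8, 16)
create_optimizer = make_optimizer_factory("adam")
optimizer = create_optimizer(embedding.parameters(), lr=1e-4)

One consequence of this design, visible in the last hunk, is that the Adam-specific keyword arguments (betas, eps, weight_decay, amsgrad) no longer leak into the call site, so Lion can be constructed from the same code path without accepting them.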