summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-18 13:00:13 +0100
committerVolpeon <git@volpeon.ink>2023-02-18 13:00:13 +0100
commit2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch)
treee08741c9df3b30a05ade472da45d7410bbf972ae /train_ti.py
parentAdded Lion optimizer (diff)
downloadtextual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz
textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2
textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index 507d710..12e3644 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -292,7 +292,7 @@ def parse_args():
292 parser.add_argument( 292 parser.add_argument(
293 "--optimizer", 293 "--optimizer",
294 type=str, 294 type=str,
295 default="lion", 295 default="adam",
296 help='Optimizer to use ["adam", "adam8bit", "lion"]' 296 help='Optimizer to use ["adam", "adam8bit", "lion"]'
297 ) 297 )
298 parser.add_argument( 298 parser.add_argument(
@@ -586,13 +586,15 @@ def main():
586 eps=args.adam_epsilon, 586 eps=args.adam_epsilon,
587 amsgrad=args.adam_amsgrad, 587 amsgrad=args.adam_amsgrad,
588 ) 588 )
589 else: 589 elif args.optimizer == 'lion':
590 try: 590 try:
591 from lion_pytorch import Lion 591 from lion_pytorch import Lion
592 except ImportError: 592 except ImportError:
593 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") 593 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
594 594
595 create_optimizer = partial(Lion, use_triton=True) 595 create_optimizer = partial(Lion, use_triton=True)
596 else:
597 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
596 598
597 checkpoint_output_dir = output_dir/"checkpoints" 599 checkpoint_output_dir = output_dir/"checkpoints"
598 600