From a08cd4b1581ca195f619e8bdb6cb6448287d4d2f Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 3 Apr 2023 12:39:17 +0200
Subject: Bring back Lion optimizer

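Adds "lion" to the --optimizer choices of train_dreambooth.py, train_lora.py
and train_ti.py, backed by the lion-pytorch package (pinned to 0.0.7 in
environment.yaml). --adam_beta1 and --adam_beta2 now default to None and are
resolved per optimizer when left unset: 0.9/0.999 for adam and adam8bit,
0.95/0.98 for lion. The Lion optimizer is built with use_triton=True and
reuses --adam_weight_decay for its weight decay.

Usage sketch (only the optimizer-related flag is shown; the output path is a
placeholder and the remaining required arguments are omitted):

    python train_ti.py --optimizer lion --output_dir output/
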
---
 environment.yaml    |  1 +
 train_dreambooth.py | 30 +++++++++++++++++++++++++++---
 train_lora.py       | 30 +++++++++++++++++++++++++++---
 train_ti.py         | 30 +++++++++++++++++++++++++++---
 4 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index 8868532..1de76bd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
           - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
           - accelerate==0.17.1
           - bitsandbytes==0.37.2
+          - lion-pytorch==0.0.7
           - peft==0.2.0
           - python-slugify>=6.1.2
           - safetensors==0.3.0
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48b7926..be7d6fe 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -317,13 +317,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -450,6 +450,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -536,6 +548,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
diff --git a/train_lora.py b/train_lora.py
index cf73645..a0cd174 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -318,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -329,13 +329,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -468,6 +468,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -568,6 +580,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
diff --git a/train_ti.py b/train_ti.py
index 651dfbe..c242625 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -341,13 +341,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -566,6 +566,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -666,6 +678,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
-- 
cgit v1.2.3-70-g09d2