From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 1 Apr 2023 12:35:43 +0200
Subject: Add support for Adafactor, add TI initializer noise

---
 models/clip/embeddings.py | 10 +++++++++-
 train_dreambooth.py       | 16 +++++++++++++++-
 train_lora.py             | 16 +++++++++++++++-
 train_ti.py               | 25 +++++++++++++++++++++++--
 training/functional.py    |  3 ++-
 training/optimization.py  |  3 +++
 6 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4166dc6..9abd1bb 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
-    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+    def add_embed(
+        self,
+        token_ids: Union[int, list[int]],
+        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
+        initializer_noise: float = 0.0,
+    ):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             dtype=self.temp_token_embedding.weight.dtype,
         )
 
+        if initializer_noise != 0:
+            initializer += torch.randn_like(initializer) * initializer_noise
+
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
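
A minimal, self-contained sketch of what the new initializer_noise argument does, with hypothetical sizes and values: a non-zero value perturbs the initializer embedding with Gaussian noise (initializer_noise acts as the standard deviation), so several vectors seeded from the same initializer word no longer start out identical.

    import torch

    embedding_dim = 768        # assumption: CLIP text encoder width
    num_vectors = 3            # hypothetical: three vectors for one placeholder token
    initializer_noise = 0.05   # hypothetical noise scale

    # Stand-in for the initializer tensor built above from the initializer word(s).
    initializer = torch.ones(num_vectors, embedding_dim)

    if initializer_noise != 0:
        initializer = initializer + torch.randn_like(initializer) * initializer_noise
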
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3a25efa..4456bd1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,6 +13,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -305,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -535,6 +536,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
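
A minimal sketch of the optimizer the partial above constructs, assuming the transformers Adafactor implementation: with relative_step=True the optimizer derives its own step size per update, so it has to be built with lr=None, which is why the branch clears args.learning_rate. The beta1 and weight_decay values below are hypothetical stand-ins for args.adam_beta1 and args.adam_weight_decay.

    import torch
    import transformers

    params = [torch.nn.Parameter(torch.zeros(8))]   # placeholder parameters

    optimizer = transformers.optimization.Adafactor(
        params,
        lr=None,               # required when relative_step=True
        beta1=0.9,             # hypothetical; enables first-moment averaging
        weight_decay=1e-2,     # hypothetical
        scale_parameter=True,  # scale updates by the RMS of each parameter
        relative_step=True,    # derive the step size from the step count
        warmup_init=True,      # warm the relative step size up from a small value
    )
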
diff --git a/train_lora.py b/train_lora.py
index f74a438..f8dccae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -317,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -567,6 +568,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
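
The same branch as in train_dreambooth.py; a short sketch of the constraint behind it, assuming current transformers behaviour: Adafactor rejects a manual lr while relative_step=True (and warmup_init requires relative_step), which is why the branch stashes the configured learning rate in args.lr_min_lr and then sets args.learning_rate to None.

    import torch
    import transformers

    params = [torch.nn.Parameter(torch.zeros(8))]

    try:
        # Combining a manual lr with relative_step=True is rejected by transformers.
        transformers.optimization.Adafactor(params, lr=1e-4, relative_step=True)
    except ValueError as err:
        print(err)
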
diff --git a/train_ti.py b/train_ti.py
index dd015f9..274a1ca 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -74,6 +75,12 @@ def parse_args():
         nargs='*',
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
     parser.add_argument(
         "--alias_tokens",
         type=str,
@@ -323,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -659,6 +666,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -739,7 +759,8 @@ def main():
             embeddings=embeddings,
             placeholder_tokens=placeholder_tokens,
             initializer_tokens=initializer_tokens,
-            num_vectors=num_vectors
+            num_vectors=num_vectors,
+            initializer_noise=args.initializer_noise,
         )
 
         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
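
A usage sketch of the new plumbing with hypothetical token names and noise value, assuming the tokenizer and embeddings objects set up earlier in main(): the value of --initializer_noise simply rides along from the CLI through add_placeholder_tokens to ManagedCLIPTextEmbeddings.add_embed.

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,               # assumed: tokenizer created earlier in main()
        embeddings=embeddings,             # assumed: patched text-encoder embeddings
        placeholder_tokens=["<concept>"],  # hypothetical placeholder token
        initializer_tokens=["person"],     # hypothetical initializer word
        num_vectors=[2],
        initializer_noise=0.05,            # value of --initializer_noise
    )
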
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
     num_vectors: Optional[Union[list[int], int]] = None,
+    initializer_noise: float = 0.0,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
     embeddings.resize(len(tokenizer))
 
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
 
     return placeholder_token_ids, initializer_token_ids
 
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+import transformers
 
 
 class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
             num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
+    elif id == "adafactor":
+        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
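
A minimal sketch of the new "adafactor" scheduler branch, assuming transformers' AdafactorSchedule: it is a LambdaLR proxy that reports the step size Adafactor computes internally, so code that queries lr_scheduler.get_last_lr() (e.g. for logging) keeps working. The second argument (min_lr in the call above) is only reported as the lr until the optimizer has taken its first step.

    import torch
    import transformers

    params = [torch.nn.Parameter(torch.zeros(8))]
    optimizer = transformers.optimization.Adafactor(
        params, lr=None, scale_parameter=True, relative_step=True, warmup_init=True
    )

    lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, 1e-4)
    print(lr_scheduler.get_last_lr())   # [0.0001] before the first optimizer step
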
-- 
cgit v1.2.3-70-g09d2