From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Apr 2023 12:35:43 +0200 Subject: Add support for Adafactor, add TI initializer noise --- models/clip/embeddings.py | 10 +++++++++- train_dreambooth.py | 16 +++++++++++++++- train_lora.py | 16 +++++++++++++++- train_ti.py | 25 +++++++++++++++++++++++-- training/functional.py | 3 ++- training/optimization.py | 3 +++ 6 files changed, 67 insertions(+), 6 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 4166dc6..9abd1bb 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) - def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): + def add_embed( + self, + token_ids: Union[int, list[int]], + initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None, + initializer_noise: float = 0.0, + ): if isinstance(token_ids, int): token_ids = [token_ids] @@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): dtype=self.temp_token_embedding.weight.dtype, ) + if initializer_noise != 0: + initializer += torch.randn_like(initializer) * initializer_noise + token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) diff --git a/train_dreambooth.py b/train_dreambooth.py index 3a25efa..4456bd1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -13,6 +13,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from slugify import slugify +import transformers from util.files import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter @@ -305,7 +306,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' + help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' ) parser.add_argument( "--dadaptation_d0", @@ -535,6 +536,19 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adafactor': + create_optimizer = partial( + transformers.optimization.Adafactor, + beta1=args.adam_beta1, + weight_decay=args.adam_weight_decay, + scale_parameter=True, + relative_step=True, + warmup_init=True, + ) + + args.lr_scheduler = "adafactor" + args.lr_min_lr = args.learning_rate + args.learning_rate = None elif args.optimizer == 'dadam': try: import dadaptation diff --git a/train_lora.py b/train_lora.py index f74a438..f8dccae 100644 --- a/train_lora.py +++ b/train_lora.py @@ -14,6 +14,7 @@ from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from peft import LoraConfig, LoraModel from slugify import slugify +import transformers from util.files import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter @@ -317,7 +318,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' + help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' ) parser.add_argument( "--dadaptation_d0", @@ -567,6 +568,19 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adafactor': + create_optimizer = partial( + transformers.optimization.Adafactor, + beta1=args.adam_beta1, + weight_decay=args.adam_weight_decay, + scale_parameter=True, + relative_step=True, + warmup_init=True, + ) + + args.lr_scheduler = "adafactor" + args.lr_min_lr = args.learning_rate + args.learning_rate = None elif args.optimizer == 'dadam': try: import dadaptation diff --git a/train_ti.py b/train_ti.py index dd015f9..274a1ca 100644 --- a/train_ti.py +++ b/train_ti.py @@ -12,6 +12,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from slugify import slugify +import transformers from util.files import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter @@ -74,6 +75,12 @@ def parse_args(): nargs='*', help="A token to use as initializer word." ) + parser.add_argument( + "--initializer_noise", + type=float, + default=0, + help="Noise to apply to the initializer word" + ) parser.add_argument( "--alias_tokens", type=str, @@ -323,7 +330,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' + help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' ) parser.add_argument( "--dadaptation_d0", @@ -659,6 +666,19 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adafactor': + create_optimizer = partial( + transformers.optimization.Adafactor, + beta1=args.adam_beta1, + weight_decay=args.adam_weight_decay, + scale_parameter=True, + relative_step=True, + warmup_init=True, + ) + + args.lr_scheduler = "adafactor" + args.lr_min_lr = args.learning_rate + args.learning_rate = None elif args.optimizer == 'dadam': try: import dadaptation @@ -739,7 +759,8 @@ def main(): embeddings=embeddings, placeholder_tokens=placeholder_tokens, initializer_tokens=initializer_tokens, - num_vectors=num_vectors + num_vectors=num_vectors, + initializer_noise=args.initializer_noise, ) stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) diff --git a/training/functional.py b/training/functional.py index a2aa24e..ac43847 100644 --- a/training/functional.py +++ b/training/functional.py @@ -231,6 +231,7 @@ def add_placeholder_tokens( placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Optional[Union[list[int], int]] = None, + initializer_noise: float = 0.0, ): initializer_token_ids = [ tokenizer.encode(token, add_special_tokens=False) @@ -245,7 +246,7 @@ def add_placeholder_tokens( embeddings.resize(len(tokenizer)) for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): - embeddings.add_embed(placeholder_token_id, initializer_token_id) + embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) return placeholder_token_ids, initializer_token_ids diff --git a/training/optimization.py b/training/optimization.py index 59ca950..53d0a6d 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -6,6 +6,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup +import transformers class OneCyclePhase(NamedTuple): @@ -148,6 +149,8 @@ def get_scheduler( num_training_steps=num_training_steps, num_cycles=cycles, ) + elif id == "adafactor": + lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) else: lr_scheduler = get_scheduler_( id, -- cgit v1.2.3-54-g00ecf