From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 12:35:43 +0200
Subject: Add support for Adafactor, add TI initializer noise

---
 training/functional.py   | 3 ++-
 training/optimization.py | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
     num_vectors: Optional[Union[list[int], int]] = None,
+    initializer_noise: float = 0.0,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
     embeddings.resize(len(tokenizer))
 
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
 
     return placeholder_token_ids, initializer_token_ids
 
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+import transformers
 
 
 class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
             num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
+    elif id == "adafactor":
+        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
--
cgit v1.2.3-54-g00ecf
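
Note on the initializer noise change: the patch threads a new `initializer_noise`
parameter from `add_placeholder_tokens` through to `embeddings.add_embed`, but the
`add_embed` implementation itself is outside this diff. Below is a minimal sketch of
what such a method plausibly does; the signature, tensor layout, and Gaussian noise
model are assumptions for illustration, not code from this repository:

    import torch

    def add_embed(embedding_matrix: torch.Tensor,
                  placeholder_token_id: int,
                  initializer_token_id: int,
                  initializer_noise: float = 0.0) -> None:
        # Hypothetical sketch: start the placeholder (textual-inversion) vector
        # as a copy of the initializer token's embedding.
        vec = embedding_matrix[initializer_token_id].clone()
        if initializer_noise > 0.0:
            # Perturb the copy so training does not begin from an exact
            # duplicate of an existing token's embedding.
            vec = vec + torch.randn_like(vec) * initializer_noise
        with torch.no_grad():
            embedding_matrix[placeholder_token_id] = vec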
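
Note on the Adafactor change: `transformers.optimization.AdafactorSchedule` is a
proxy scheduler. Adafactor computes its step sizes internally when running in
relative-step mode, and the schedule object exists so generic training loops can
still step and log an lr scheduler; the second positional argument passed here,
`min_lr`, corresponds to the schedule's `initial_lr` parameter in transformers.
A minimal usage sketch pairing it with the matching optimizer follows; the
`model` variable is a placeholder, not a name from this repository:

    import torch.nn as nn
    from transformers.optimization import Adafactor, AdafactorSchedule

    model = nn.Linear(8, 8)  # stand-in for the real model

    # With relative_step=True, Adafactor derives the learning rate itself,
    # so lr must be left as None.
    optimizer = Adafactor(
        model.parameters(),
        scale_parameter=True,
        relative_step=True,
        warmup_init=True,
        lr=None,
    )

    # Proxy schedule that reports Adafactor's internally computed lr.
    lr_scheduler = AdafactorSchedule(optimizer)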