summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py3
-rw-r--r--training/optimization.py3
2 files changed, 5 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
231 placeholder_tokens: list[str], 231 placeholder_tokens: list[str],
232 initializer_tokens: list[str], 232 initializer_tokens: list[str],
233 num_vectors: Optional[Union[list[int], int]] = None, 233 num_vectors: Optional[Union[list[int], int]] = None,
234 initializer_noise: float = 0.0,
234): 235):
235 initializer_token_ids = [ 236 initializer_token_ids = [
236 tokenizer.encode(token, add_special_tokens=False) 237 tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
245 embeddings.resize(len(tokenizer)) 246 embeddings.resize(len(tokenizer))
246 247
247 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): 248 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
248 embeddings.add_embed(placeholder_token_id, initializer_token_id) 249 embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
249 250
250 return placeholder_token_ids, initializer_token_ids 251 return placeholder_token_ids, initializer_token_ids
251 252
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers
9 10
10 11
11class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
148 num_training_steps=num_training_steps, 149 num_training_steps=num_training_steps,
149 num_cycles=cycles, 150 num_cycles=cycles,
150 ) 151 )
152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
151 else: 154 else:
152 lr_scheduler = get_scheduler_( 155 lr_scheduler = get_scheduler_(
153 id, 156 id,