Diffstat (limited to 'training')
 -rw-r--r--  training/functional.py   | 3 ++-
 -rw-r--r--  training/optimization.py | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
     num_vectors: Optional[Union[list[int], int]] = None,
+    initializer_noise: float = 0.0,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
     embeddings.resize(len(tokenizer))
 
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
 
     return placeholder_token_ids, initializer_token_ids
 
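Note: the new initializer_noise argument is only threaded through to embeddings.add_embed here; the actual noise handling lives in the embeddings class and is not part of this diff. The following is a minimal, self-contained sketch of the assumed idea (init_with_noise and the 768-dim tensor are placeholders, not repository code): start the placeholder embedding from the initializer token's embedding and add small Gaussian jitter scaled by initializer_noise.

# Sketch only: assumed behavior of a noisy initializer copy; the real logic
# is inside add_embed, which this diff does not show.
import torch

def init_with_noise(initializer_embed: torch.Tensor, initializer_noise: float = 0.0) -> torch.Tensor:
    # Copy the initializer embedding; optionally jitter it so the new
    # placeholder token does not start as an exact duplicate.
    new_embed = initializer_embed.clone()
    if initializer_noise > 0:
        new_embed = new_embed + torch.randn_like(new_embed) * initializer_noise
    return new_embed

# Example with a stand-in text-encoder embedding row.
base = torch.randn(768)
noisy = init_with_noise(base, initializer_noise=0.1)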
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+import transformers
 
 
 class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
             num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
+    elif id == "adafactor":
+        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
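With the new branch, id == "adafactor" returns a transformers AdafactorSchedule, which does not set the learning rate itself but proxies the step size that Adafactor computes internally when it runs in relative-step mode; the min_lr passed here maps onto the schedule's initial/reported learning rate. A minimal usage sketch, assuming the optimizer is transformers.Adafactor in relative-step mode (the model, min_lr value, and training loop below are placeholders, not code from this repository):

# Assumed pairing for the "adafactor" scheduler id: an Adafactor optimizer
# in relative-step mode plus AdafactorSchedule to report its learning rate.
import torch
import transformers
from transformers.optimization import Adafactor, AdafactorSchedule

model = torch.nn.Linear(16, 16)  # placeholder model
optimizer = Adafactor(
    model.parameters(),
    scale_parameter=True,
    relative_step=True,   # Adafactor derives its own step size
    warmup_init=True,
    lr=None,
)

# Equivalent to the new get_scheduler(id="adafactor", ...) branch.
min_lr = 1e-6
lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)

for step in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    print(step, lr_scheduler.get_last_lr())  # lr as computed by Adafactor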