diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 3 |
1 file changed, 2 insertions, 1 deletion
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -231,6 +231,7 @@ def add_placeholder_tokens( | |||
| 231 | placeholder_tokens: list[str], | 231 | placeholder_tokens: list[str], |
| 232 | initializer_tokens: list[str], | 232 | initializer_tokens: list[str], |
| 233 | num_vectors: Optional[Union[list[int], int]] = None, | 233 | num_vectors: Optional[Union[list[int], int]] = None, |
| 234 | initializer_noise: float = 0.0, | ||
| 234 | ): | 235 | ): |
| 235 | initializer_token_ids = [ | 236 | initializer_token_ids = [ |
| 236 | tokenizer.encode(token, add_special_tokens=False) | 237 | tokenizer.encode(token, add_special_tokens=False) |
| @@ -245,7 +246,7 @@ def add_placeholder_tokens( | |||
| 245 | embeddings.resize(len(tokenizer)) | 246 | embeddings.resize(len(tokenizer)) |
| 246 | 247 | ||
| 247 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 248 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): |
| 248 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | 249 | embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) |
| 249 | 250 | ||
| 250 | return placeholder_token_ids, initializer_token_ids | 251 | return placeholder_token_ids, initializer_token_ids |
| 251 | 252 | ||
