summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-01 12:35:43 +0200
committerVolpeon <git@volpeon.ink>2023-04-01 12:35:43 +0200
commit01eee0cb24f52ca78761b78917959e1c247eae94 (patch)
tree914c0d3f5b888a4c344b30a861639c8e3d5259dd /training/functional.py
parentUpdate (diff)
downloadtextual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz
textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2
textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
231 placeholder_tokens: list[str], 231 placeholder_tokens: list[str],
232 initializer_tokens: list[str], 232 initializer_tokens: list[str],
233 num_vectors: Optional[Union[list[int], int]] = None, 233 num_vectors: Optional[Union[list[int], int]] = None,
234 initializer_noise: float = 0.0,
234): 235):
235 initializer_token_ids = [ 236 initializer_token_ids = [
236 tokenizer.encode(token, add_special_tokens=False) 237 tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
245 embeddings.resize(len(tokenizer)) 246 embeddings.resize(len(tokenizer))
246 247
247 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): 248 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
248 embeddings.add_embed(placeholder_token_id, initializer_token_id) 249 embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
249 250
250 return placeholder_token_ids, initializer_token_ids 251 return placeholder_token_ids, initializer_token_ids
251 252