diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
commit | 01eee0cb24f52ca78761b78917959e1c247eae94 (patch) | |
tree | 914c0d3f5b888a4c344b30a861639c8e3d5259dd /models | |
parent | Update (diff) | |
download | textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2 textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip |
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 4166dc6..9abd1bb 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | 52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) |
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 54 | ||
55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed( |
56 | self, | ||
57 | token_ids: Union[int, list[int]], | ||
58 | initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None, | ||
59 | initializer_noise: float = 0.0, | ||
60 | ): | ||
56 | if isinstance(token_ids, int): | 61 | if isinstance(token_ids, int): |
57 | token_ids = [token_ids] | 62 | token_ids = [token_ids] |
58 | 63 | ||
@@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
73 | dtype=self.temp_token_embedding.weight.dtype, | 78 | dtype=self.temp_token_embedding.weight.dtype, |
74 | ) | 79 | ) |
75 | 80 | ||
81 | if initializer_noise != 0: | ||
82 | initializer += torch.randn_like(initializer) * initializer_noise | ||
83 | |||
76 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 84 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
77 | 85 | ||
78 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 86 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |