From 01eee0cb24f52ca78761b78917959e1c247eae94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Apr 2023 12:35:43 +0200 Subject: Add support for Adafactor, add TI initializer noise --- models/clip/embeddings.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 4166dc6..9abd1bb 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) - def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): + def add_embed( + self, + token_ids: Union[int, list[int]], + initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None, + initializer_noise: float = 0.0, + ): if isinstance(token_ids, int): token_ids = [token_ids] @@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): dtype=self.temp_token_embedding.weight.dtype, ) + if initializer_noise != 0: + initializer += torch.randn_like(initializer) * initializer_noise + token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) -- cgit v1.2.3-54-g00ecf