diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 53 |
1 files changed, 38 insertions, 15 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9abd1bb..88e0cc0 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
| 31 | return new_embedding | 31 | return new_embedding |
| 32 | 32 | ||
| 33 | 33 | ||
| 34 | class OverlayLinear(nn.Module): | ||
| 35 | def __init__(self, in_features, out_features, rank=4): | ||
| 36 | super().__init__() | ||
| 37 | |||
| 38 | if rank > min(in_features, out_features): | ||
| 39 | raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}") | ||
| 40 | |||
| 41 | self.rank = rank | ||
| 42 | self.down = nn.Linear(in_features, rank, bias=False) | ||
| 43 | self.up = nn.Linear(rank, out_features, bias=False) | ||
| 44 | self.reset() | ||
| 45 | |||
| 46 | def reset(self): | ||
| 47 | nn.init.normal_(self.down.weight, std=1 / self.rank) | ||
| 48 | nn.init.zeros_(self.up.weight) | ||
| 49 | |||
| 50 | def forward(self, hidden_states): | ||
| 51 | orig_dtype = hidden_states.dtype | ||
| 52 | dtype = self.down.weight.dtype | ||
| 53 | |||
| 54 | down_hidden_states = self.down(hidden_states.to(dtype)) | ||
| 55 | up_hidden_states = self.up(down_hidden_states) | ||
| 56 | |||
| 57 | return up_hidden_states.to(orig_dtype) | ||
| 58 | |||
| 59 | |||
| 34 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 60 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 35 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | 61 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128): |
| 36 | super().__init__(config) | 62 | super().__init__(config) |
| 37 | 63 | ||
| 38 | self.token_embedding = embeddings.token_embedding | 64 | self.token_embedding = embeddings.token_embedding |
| 39 | self.position_embedding = embeddings.position_embedding | 65 | self.position_embedding = embeddings.position_embedding |
| 40 | self.initializer_factor = config.initializer_factor | 66 | self.initializer_factor = config.initializer_factor |
| 41 | 67 | ||
| 42 | self.temp_token_embedding = nn.Embedding( | 68 | self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) |
| 43 | self.token_embedding.num_embeddings, | ||
| 44 | self.token_embedding.embedding_dim, | ||
| 45 | device=self.token_embedding.weight.device, | ||
| 46 | dtype=self.token_embedding.weight.dtype | ||
| 47 | ) | ||
| 48 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
| 49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 69 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 50 | 70 | ||
| 71 | def reset_overlay(self): | ||
| 72 | self.overlay.reset() | ||
| 73 | |||
| 51 | def resize(self, size: int): | 74 | def resize(self, size: int): |
| 52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | ||
| 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 75 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 54 | 76 | ||
| 55 | def add_embed( | 77 | def add_embed( |
| @@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 74 | initializer = self.get_embed(initializer) | 96 | initializer = self.get_embed(initializer) |
| 75 | 97 | ||
| 76 | initializer = initializer.to( | 98 | initializer = initializer.to( |
| 77 | device=self.temp_token_embedding.weight.device, | 99 | device=self.token_embedding.weight.device, |
| 78 | dtype=self.temp_token_embedding.weight.dtype, | 100 | dtype=self.token_embedding.weight.dtype, |
| 79 | ) | 101 | ) |
| 80 | 102 | ||
| 81 | if initializer_noise != 0: | 103 | if initializer_noise != 0: |
| @@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 84 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 106 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 85 | 107 | ||
| 86 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 108 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 87 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
| 88 | self.token_embedding.weight.data[token_ids] = initializer | 109 | self.token_embedding.weight.data[token_ids] = initializer |
| 89 | 110 | ||
| 90 | def load_embed(self, input_ids: list[int], filename: Path): | 111 | def load_embed(self, input_ids: list[int], filename: Path): |
| @@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 95 | save_file({"embed": self.get_embed(input_ids)}, filename) | 116 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 96 | 117 | ||
| 97 | def persist(self): | 118 | def persist(self): |
| 98 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 119 | self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( |
| 120 | self.token_embedding.weight.data[self.temp_token_ids] | ||
| 121 | ) | ||
| 122 | self.overlay.reset() | ||
| 99 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 123 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 100 | 124 | ||
| 101 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 125 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| @@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 103 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 127 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
| 104 | 128 | ||
| 105 | embeds = self.token_embedding(input_ids) | 129 | embeds = self.token_embedding(input_ids) |
| 106 | |||
| 107 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 130 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) |
| 108 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 131 | embeds[mask] += self.overlay(embeds[mask]) |
| 109 | 132 | ||
| 110 | return embeds | 133 | return embeds |
| 111 | 134 | ||
