Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 42
1 file changed, 4 insertions, 38 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha
 
-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])
 
         return embeds
 
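Summary of the change: the LoRA-style OverlayLinear (a rank-limited down/up projection whose up weights start at zero, so the initial delta is zero) is removed. Temporary token vectors are now trained directly in temp_token_embedding, copied verbatim into token_embedding by persist(), and substituted as-is in get_embed(); the new alpha argument is only stored here, its use lies outside the hunks shown. A minimal standalone sketch of the resulting lookup and persist path, with illustrative sizes and token ids (not taken from this repository):

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the real values come from the CLIP text model config.
    num_embeddings, embedding_dim = 49411, 768

    token_embedding = nn.Embedding(num_embeddings, embedding_dim)
    temp_token_embedding = nn.Embedding(num_embeddings, embedding_dim)
    temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()

    # Hypothetical ids of newly added (trainable) tokens.
    temp_token_ids = torch.tensor([49408, 49409, 49410], dtype=torch.long)

    def get_embed(input_ids: torch.LongTensor) -> torch.Tensor:
        # Base lookup, then overwrite positions that refer to temporary tokens
        # with rows from temp_token_embedding -- no overlay delta is added anymore.
        embeds = token_embedding(input_ids)
        mask = torch.isin(input_ids, temp_token_ids)
        embeds[mask] = temp_token_embedding(input_ids[mask])
        return embeds

    def persist() -> None:
        # Copy the trained temporary rows straight into the permanent embedding.
        token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]

    embeds = get_embed(torch.tensor([[49406, 49408, 49407]]))  # e.g. <start> <new-token> <end>
    print(embeds.shape)  # torch.Size([1, 3, 768])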