diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 50 |
1 file changed, 33 insertions, 17 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1e21965..d8343a0 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -12,7 +12,7 @@ from transformers.models.clip import CLIPTextConfig | |||
12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings |
13 | 13 | ||
14 | 14 | ||
15 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: | 15 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: |
16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape | 16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape |
17 | 17 | ||
18 | if old_num_embeddings == new_num_embeddings: | 18 | if old_num_embeddings == new_num_embeddings: |
@@ -26,13 +26,16 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
26 | device=old_embedding.weight.device, | 26 | device=old_embedding.weight.device, |
27 | dtype=old_embedding.weight.dtype | 27 | dtype=old_embedding.weight.dtype |
28 | ) | 28 | ) |
29 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 29 | if initializer_factor is not None: |
30 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
31 | else: | ||
32 | nn.init.zeros_(new_embedding.weight.data) | ||
30 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] | 33 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] |
31 | return new_embedding | 34 | return new_embedding |
32 | 35 | ||
33 | 36 | ||
34 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 37 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
35 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4): | 38 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): |
36 | super().__init__(config) | 39 | super().__init__(config) |
37 | 40 | ||
38 | self.token_embedding = embeddings.token_embedding | 41 | self.token_embedding = embeddings.token_embedding |
@@ -40,17 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | self.initializer_factor = config.initializer_factor | 43 | self.initializer_factor = config.initializer_factor |
41 | self.alpha = alpha | 44 | self.alpha = alpha |
42 | 45 | ||
43 | self.temp_token_embedding = nn.Embedding( | 46 | self.temp_token_embedding = nn.ParameterList() |
44 | self.token_embedding.num_embeddings, | ||
45 | self.token_embedding.embedding_dim, | ||
46 | device=self.token_embedding.weight.device, | ||
47 | dtype=self.token_embedding.weight.dtype | ||
48 | ) | ||
49 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
50 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 47 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
51 | 48 | ||
52 | def resize(self, size: int): | 49 | def resize(self, size: int): |
53 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | 50 | for _ in range(len(self.temp_token_embedding), size): |
51 | self.temp_token_embedding.append(torch.zeros( | ||
52 | self.token_embedding.embedding_dim, | ||
53 | device=self.token_embedding.weight.device, | ||
54 | dtype=self.token_embedding.weight.dtype, | ||
55 | )) | ||
54 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 56 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
55 | 57 | ||
56 | def add_embed( | 58 | def add_embed( |
@@ -85,7 +87,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
85 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 87 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
86 | 88 | ||
87 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 89 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
88 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
89 | self.token_embedding.weight.data[token_ids] = initializer | 90 | self.token_embedding.weight.data[token_ids] = initializer |
90 | 91 | ||
91 | def load_embed(self, input_ids: list[int], filename: Path): | 92 | def load_embed(self, input_ids: list[int], filename: Path): |
@@ -96,16 +97,31 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
96 | save_file({"embed": self.get_embed(input_ids)}, filename) | 97 | save_file({"embed": self.get_embed(input_ids)}, filename) |
97 | 98 | ||
98 | def persist(self): | 99 | def persist(self): |
99 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 100 | for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): |
101 | self.token_embedding.weight.data[id] += self.alpha * emb | ||
102 | nn.init.zeros_(emb) | ||
100 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 103 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
101 | 104 | ||
102 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 105 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
103 | if isinstance(input_ids, list): | 106 | if isinstance(input_ids, list): |
104 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 107 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
105 | 108 | ||
109 | all_temp_token_ids = self.temp_token_ids.to(input_ids.device) | ||
110 | |||
106 | embeds = self.token_embedding(input_ids) | 111 | embeds = self.token_embedding(input_ids) |
107 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 112 | mask = torch.isin(input_ids, all_temp_token_ids) |
108 | embeds[mask] = self.temp_token_embedding(input_ids[mask]) | 113 | temp_token_ids = input_ids[mask] |
114 | |||
115 | temp_token_ids = temp_token_ids.unsqueeze(1) | ||
116 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | ||
117 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | ||
118 | |||
119 | if len(temp_token_ids): | ||
120 | embeds_override = torch.stack([ | ||
121 | self.temp_token_embedding[id] | ||
122 | for id in temp_token_ids | ||
123 | ]) | ||
124 | embeds[mask] += self.alpha * embeds_override | ||
109 | 125 | ||
110 | return embeds | 126 | return embeds |
111 | 127 | ||
@@ -129,7 +145,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
129 | return embeddings | 145 | return embeddings |
130 | 146 | ||
131 | 147 | ||
132 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | 148 | def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: |
133 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 149 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) |
134 | text_encoder.text_model.embeddings = text_embeddings | 150 | text_encoder.text_model.embeddings = text_embeddings |
135 | return text_embeddings | 151 | return text_embeddings |