diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 30 |
1 files changed, 23 insertions, 7 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 6be6e9f..8d01867 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -38,18 +38,24 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
38 | self.token_embedding = embeddings.token_embedding | 38 | self.token_embedding = embeddings.token_embedding |
39 | self.position_embedding = embeddings.position_embedding | 39 | self.position_embedding = embeddings.position_embedding |
40 | self.initializer_factor = config.initializer_factor | 40 | self.initializer_factor = config.initializer_factor |
41 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | ||
42 | self.init_temp_embeddings() | ||
41 | 43 | ||
44 | def init_temp_embeddings(self): | ||
42 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.Embedding( |
43 | self.token_embedding.num_embeddings, | 46 | 0, |
44 | self.token_embedding.embedding_dim, | 47 | self.token_embedding.embedding_dim, |
45 | device=self.token_embedding.weight.device, | 48 | device=self.token_embedding.weight.device, |
46 | dtype=self.token_embedding.weight.dtype | 49 | dtype=self.token_embedding.weight.dtype |
47 | ) | 50 | ) |
48 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 51 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
50 | 52 | ||
51 | def resize(self, size: int): | 53 | def resize(self, size: int): |
52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | 54 | self.temp_token_embedding = resize_embedding( |
55 | self.temp_token_embedding, | ||
56 | size - self.num_permanent_embeddings, | ||
57 | self.initializer_factor | ||
58 | ) | ||
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 59 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 60 | ||
55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 61 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
@@ -71,7 +77,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
71 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 77 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
72 | 78 | ||
73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | 80 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) |
81 | self.temp_token_embedding.weight.data[mask] = initializer.to( | ||
75 | device=self.temp_token_embedding.weight.device, | 82 | device=self.temp_token_embedding.weight.device, |
76 | dtype=self.temp_token_embedding.weight.dtype, | 83 | dtype=self.temp_token_embedding.weight.dtype, |
77 | ) | 84 | ) |
@@ -85,16 +92,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
85 | 92 | ||
86 | def persist(self): | 93 | def persist(self): |
87 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 94 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
88 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 95 | self.num_permanent_embeddings = self.token_embedding.num_embeddings |
96 | self.init_temp_embeddings() | ||
89 | 97 | ||
90 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 98 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
91 | if isinstance(input_ids, list): | 99 | if isinstance(input_ids, list): |
92 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 100 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
93 | 101 | ||
102 | all_temp_token_ids = self.temp_token_ids.to(input_ids.device) | ||
103 | |||
94 | embeds = self.token_embedding(input_ids) | 104 | embeds = self.token_embedding(input_ids) |
95 | 105 | ||
96 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 106 | embeds_mask = torch.isin(input_ids, all_temp_token_ids) |
97 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 107 | temp_token_ids = input_ids[embeds_mask] |
108 | |||
109 | temp_token_ids = temp_token_ids.unsqueeze(1) | ||
110 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | ||
111 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | ||
112 | |||
113 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) | ||
98 | 114 | ||
99 | return embeds | 115 | return embeds |
100 | 116 | ||