diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 34 |
1 files changed, 9 insertions, 25 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 2b315c4..2d60c28 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -38,24 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 38 | self.token_embedding = embeddings.token_embedding | 38 | self.token_embedding = embeddings.token_embedding |
| 39 | self.position_embedding = embeddings.position_embedding | 39 | self.position_embedding = embeddings.position_embedding |
| 40 | self.initializer_factor = config.initializer_factor | 40 | self.initializer_factor = config.initializer_factor |
| 41 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | ||
| 42 | self.init_temp_embeddings() | ||
| 43 | 41 | ||
| 44 | def init_temp_embeddings(self): | ||
| 45 | self.temp_token_embedding = nn.Embedding( | 42 | self.temp_token_embedding = nn.Embedding( |
| 46 | 0, | 43 | self.token_embedding.num_embeddings, |
| 47 | self.token_embedding.embedding_dim, | 44 | self.token_embedding.embedding_dim, |
| 48 | device=self.token_embedding.weight.device, | 45 | device=self.token_embedding.weight.device, |
| 49 | dtype=self.token_embedding.weight.dtype | 46 | dtype=self.token_embedding.weight.dtype |
| 50 | ) | 47 | ) |
| 48 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
| 51 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 52 | 50 | ||
| 53 | def resize(self, size: int): | 51 | def resize(self, size: int): |
| 54 | self.temp_token_embedding = resize_embedding( | 52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) |
| 55 | self.temp_token_embedding, | ||
| 56 | size - self.num_permanent_embeddings, | ||
| 57 | self.initializer_factor | ||
| 58 | ) | ||
| 59 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 60 | 54 | ||
| 61 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
| @@ -75,15 +69,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 75 | initializer = self.get_embed(initializer) | 69 | initializer = self.get_embed(initializer) |
| 76 | 70 | ||
| 77 | initializer = initializer.to( | 71 | initializer = initializer.to( |
| 78 | device=self.token_embedding.weight.device, | 72 | device=self.temp_token_embedding.weight.device, |
| 79 | dtype=self.token_embedding.weight.dtype, | 73 | dtype=self.temp_token_embedding.weight.dtype, |
| 80 | ) | 74 | ) |
| 81 | 75 | ||
| 82 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 76 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 83 | 77 | ||
| 84 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 78 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 85 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) | 79 | self.temp_token_embedding.weight.data[token_ids] = initializer |
| 86 | self.temp_token_embedding.weight.data[mask] = initializer | ||
| 87 | 80 | ||
| 88 | def load_embed(self, input_ids: list[int], filename: Path): | 81 | def load_embed(self, input_ids: list[int], filename: Path): |
| 89 | with safe_open(filename, framework="pt", device="cpu") as file: | 82 | with safe_open(filename, framework="pt", device="cpu") as file: |
| @@ -94,25 +87,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 94 | 87 | ||
| 95 | def persist(self): | 88 | def persist(self): |
| 96 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 89 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
| 97 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | 90 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 98 | self.init_temp_embeddings() | ||
| 99 | 91 | ||
| 100 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 92 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 101 | if isinstance(input_ids, list): | 93 | if isinstance(input_ids, list): |
| 102 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 94 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
| 103 | 95 | ||
| 104 | all_temp_token_ids = self.temp_token_ids.to(input_ids.device) | ||
| 105 | |||
| 106 | embeds = self.token_embedding(input_ids) | 96 | embeds = self.token_embedding(input_ids) |
| 107 | 97 | ||
| 108 | embeds_mask = torch.isin(input_ids, all_temp_token_ids) | 98 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) |
| 109 | temp_token_ids = input_ids[embeds_mask] | 99 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] |
| 110 | |||
| 111 | temp_token_ids = temp_token_ids.unsqueeze(1) | ||
| 112 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | ||
| 113 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | ||
| 114 | |||
| 115 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) | ||
| 116 | 100 | ||
| 117 | return embeds | 101 | return embeds |
| 118 | 102 | ||
