diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 15 | ||||
| -rw-r--r-- | models/sparse.py | 14 |
2 files changed, 15 insertions, 14 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index a356434..63a141f 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
| 37 | 37 | ||
| 38 | 38 | ||
| 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): | 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): |
| 41 | super().__init__(config) | 41 | super().__init__(config) |
| 42 | 42 | ||
| 43 | self.token_embedding = embeddings.token_embedding | 43 | self.token_embedding = embeddings.token_embedding |
| @@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 49 | device=self.token_embedding.weight.device, | 49 | device=self.token_embedding.weight.device, |
| 50 | dtype=self.token_embedding.weight.dtype, | 50 | dtype=self.token_embedding.weight.dtype, |
| 51 | ) | 51 | ) |
| 52 | self.alpha = alpha | ||
| 53 | 52 | ||
| 54 | def resize(self, size: int): | 53 | def resize(self, size: int): |
| 55 | self.token_override_embedding.resize(size) | 54 | self.token_override_embedding.resize(size) |
| @@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 87 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 86 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 88 | 87 | ||
| 89 | self.token_embedding.weight.data[token_ids] = initializer | 88 | self.token_embedding.weight.data[token_ids] = initializer |
| 90 | self.token_override_embedding.set(token_ids) | 89 | self.token_override_embedding.set(token_ids, initializer) |
| 91 | 90 | ||
| 92 | def load_embed(self, input_ids: list[int], filename: Path): | 91 | def load_embed(self, input_ids: list[int], filename: Path): |
| 93 | with safe_open(filename, framework="pt", device="cpu") as file: | 92 | with safe_open(filename, framework="pt", device="cpu") as file: |
| @@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 101 | embs, mask = self.token_override_embedding(input_ids) | 100 | embs, mask = self.token_override_embedding(input_ids) |
| 102 | if embs is not None: | 101 | if embs is not None: |
| 103 | input_ids = input_ids[mask] | 102 | input_ids = input_ids[mask] |
| 104 | self.token_embedding.weight.data[input_ids] += self.alpha * embs | 103 | self.token_embedding.weight.data[input_ids] = embs |
| 105 | self.token_override_embedding.unset(input_ids) | 104 | self.token_override_embedding.unset(input_ids) |
| 106 | 105 | ||
| 107 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 106 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 108 | if isinstance(input_ids, list): | 107 | if isinstance(input_ids, list): |
| @@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 111 | embs = self.token_embedding(input_ids) | 110 | embs = self.token_embedding(input_ids) |
| 112 | embs_override, mask = self.token_override_embedding(input_ids) | 111 | embs_override, mask = self.token_override_embedding(input_ids) |
| 113 | if embs_override is not None: | 112 | if embs_override is not None: |
| 114 | embs[mask] += self.alpha * embs_override | 113 | embs[mask] = embs_override |
| 115 | 114 | ||
| 116 | return embs | 115 | return embs |
| 117 | 116 | ||
| @@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 135 | return embeddings | 134 | return embeddings |
| 136 | 135 | ||
| 137 | 136 | ||
| 138 | def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: | 137 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: |
| 139 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) | 138 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) |
| 140 | text_encoder.text_model.embeddings = text_embeddings | 139 | text_encoder.text_model.embeddings = text_embeddings |
| 141 | return text_embeddings | 140 | return text_embeddings |
diff --git a/models/sparse.py b/models/sparse.py index 0b15454..8910316 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module): | |||
| 13 | self.params = nn.ParameterList() | 13 | self.params = nn.ParameterList() |
| 14 | self.mapping = torch.zeros(0, device=device, dtype=torch.long) | 14 | self.mapping = torch.zeros(0, device=device, dtype=torch.long) |
| 15 | 15 | ||
| 16 | def forward(self, input_ids: Optional[torch.LongTensor] = None): | 16 | def forward(self, input_ids: torch.LongTensor): |
| 17 | if input_ids is None: | ||
| 18 | input_ids = torch.arange(self.mapping.shape[0]) | ||
| 19 | |||
| 20 | ids = self.mapping[input_ids.to(self.mapping.device)] | 17 | ids = self.mapping[input_ids.to(self.mapping.device)] |
| 21 | mask = ~(ids == -1) | 18 | mask = ~(ids == -1) |
| 22 | 19 | ||
| @@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module): | |||
| 43 | else: | 40 | else: |
| 44 | return [self.set(id) for id in input_ids] | 41 | return [self.set(id) for id in input_ids] |
| 45 | 42 | ||
| 43 | if tensor is None: | ||
| 44 | tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) | ||
| 45 | |||
| 46 | if tensor.shape[-1] != self.embedding_dim: | ||
| 47 | raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") | ||
| 48 | |||
| 46 | id = self.mapping[input_ids] | 49 | id = self.mapping[input_ids] |
| 47 | 50 | ||
| 48 | if id == -1: | 51 | if id == -1: |
| @@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module): | |||
| 50 | self.mapping[input_ids] = id | 53 | self.mapping[input_ids] = id |
| 51 | self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) | 54 | self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) |
| 52 | 55 | ||
| 53 | self.params[id] = tensor if tensor is not None else torch.zeros( | 56 | self.params[id] = tensor |
| 54 | self.embedding_dim, device=self.mapping.device, dtype=self.dtype) | ||
| 55 | 57 | ||
| 56 | def unset(self, input_ids: torch.LongTensor): | 58 | def unset(self, input_ids: torch.LongTensor): |
| 57 | self.mapping[input_ids] = -1 | 59 | self.mapping[input_ids] = -1 |
