diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 7 |
| -rw-r--r-- | models/sparse.py | 13 |
2 files changed, 14 insertions, 6 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 6fda33c..dc4708a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
| 37 | 37 | ||
| 38 | 38 | ||
| 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): |
| 41 | super().__init__(config) | 41 | super().__init__(config) |
| 42 | 42 | ||
| 43 | self.token_embedding = embeddings.token_embedding | 43 | self.token_embedding = embeddings.token_embedding |
| @@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 46 | 46 | ||
| 47 | self.token_override_embedding = PseudoSparseEmbedding( | 47 | self.token_override_embedding = PseudoSparseEmbedding( |
| 48 | self.token_embedding.embedding_dim, | 48 | self.token_embedding.embedding_dim, |
| 49 | dropout_p=dropout_p, | ||
| 49 | device=self.token_embedding.weight.device, | 50 | device=self.token_embedding.weight.device, |
| 50 | dtype=self.token_embedding.weight.dtype, | 51 | dtype=self.token_embedding.weight.dtype, |
| 51 | ) | 52 | ) |
| @@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 134 | return embeddings | 135 | return embeddings |
| 135 | 136 | ||
| 136 | 137 | ||
| 137 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | 138 | def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: |
| 138 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 139 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) |
| 139 | text_encoder.text_model.embeddings = text_embeddings | 140 | text_encoder.text_model.embeddings = text_embeddings |
| 140 | return text_embeddings | 141 | return text_embeddings |
diff --git a/models/sparse.py b/models/sparse.py index d706db5..bcb2897 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -5,22 +5,29 @@ import torch.nn as nn | |||
| 5 | 5 | ||
| 6 | 6 | ||
| 7 | class PseudoSparseEmbedding(nn.Module): | 7 | class PseudoSparseEmbedding(nn.Module): |
| 8 | def __init__(self, embedding_dim: int, device=None, dtype=torch.float32): | 8 | def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32): |
| 9 | super().__init__() | 9 | super().__init__() |
| 10 | 10 | ||
| 11 | self.embedding_dim = embedding_dim | 11 | self.embedding_dim = embedding_dim |
| 12 | self.dtype = dtype | 12 | self.dtype = dtype |
| 13 | self.params = nn.ParameterList() | 13 | self.params = nn.ParameterList() |
| 14 | |||
| 15 | if dropout_p > 0.0: | ||
| 16 | self.dropout = nn.Dropout(p=dropout_p) | ||
| 17 | else: | ||
| 18 | self.dropout = lambda x: x | ||
| 19 | |||
| 14 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) | 20 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) |
| 15 | 21 | ||
| 16 | def forward(self, input_ids: torch.LongTensor): | 22 | def forward(self, input_ids: torch.LongTensor): |
| 17 | ids = self.mapping[input_ids.to(self.mapping.device)] | 23 | input_ids = input_ids.to(self.mapping.device) |
| 24 | ids = self.mapping[input_ids] | ||
| 18 | mask = ~(ids == -1) | 25 | mask = ~(ids == -1) |
| 19 | 26 | ||
| 20 | if torch.all(~mask): | 27 | if torch.all(~mask): |
| 21 | embs = None | 28 | embs = None |
| 22 | else: | 29 | else: |
| 23 | embs = torch.stack([self.params[id] for id in ids[mask]]) | 30 | embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]])) |
| 24 | 31 | ||
| 25 | return embs, mask | 32 | return embs, mask |
| 26 | 33 | ||
