diff options
Diffstat (limited to 'models/clip/embeddings.py')
| -rw-r--r-- | models/clip/embeddings.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 6fda33c..dc4708a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
| 37 | 37 | ||
| 38 | 38 | ||
| 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): |
| 41 | super().__init__(config) | 41 | super().__init__(config) |
| 42 | 42 | ||
| 43 | self.token_embedding = embeddings.token_embedding | 43 | self.token_embedding = embeddings.token_embedding |
| @@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 46 | 46 | ||
| 47 | self.token_override_embedding = PseudoSparseEmbedding( | 47 | self.token_override_embedding = PseudoSparseEmbedding( |
| 48 | self.token_embedding.embedding_dim, | 48 | self.token_embedding.embedding_dim, |
| 49 | dropout_p=dropout_p, | ||
| 49 | device=self.token_embedding.weight.device, | 50 | device=self.token_embedding.weight.device, |
| 50 | dtype=self.token_embedding.weight.dtype, | 51 | dtype=self.token_embedding.weight.dtype, |
| 51 | ) | 52 | ) |
| @@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 134 | return embeddings | 135 | return embeddings |
| 135 | 136 | ||
| 136 | 137 | ||
| 137 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | 138 | def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: |
| 138 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 139 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) |
| 139 | text_encoder.text_model.embeddings = text_embeddings | 140 | text_encoder.text_model.embeddings = text_embeddings |
| 140 | return text_embeddings | 141 | return text_embeddings |
