diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 6fda33c..dc4708a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
37 | 37 | ||
38 | 38 | ||
39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): |
41 | super().__init__(config) | 41 | super().__init__(config) |
42 | 42 | ||
43 | self.token_embedding = embeddings.token_embedding | 43 | self.token_embedding = embeddings.token_embedding |
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
46 | 46 | ||
47 | self.token_override_embedding = PseudoSparseEmbedding( | 47 | self.token_override_embedding = PseudoSparseEmbedding( |
48 | self.token_embedding.embedding_dim, | 48 | self.token_embedding.embedding_dim, |
49 | dropout_p=dropout_p, | ||
49 | device=self.token_embedding.weight.device, | 50 | device=self.token_embedding.weight.device, |
50 | dtype=self.token_embedding.weight.dtype, | 51 | dtype=self.token_embedding.weight.dtype, |
51 | ) | 52 | ) |
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
134 | return embeddings | 135 | return embeddings |
135 | 136 | ||
136 | 137 | ||
137 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | 138 | def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: |
138 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 139 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) |
139 | text_encoder.text_model.embeddings = text_embeddings | 140 | text_encoder.text_model.embeddings = text_embeddings |
140 | return text_embeddings | 141 | return text_embeddings |