summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-09 11:29:31 +0200
committerVolpeon <git@volpeon.ink>2023-04-09 11:29:31 +0200
commitba9fd1a10746d85d2502c8a79ac49db63d346b04 (patch)
tree568bf65a0a4dcea2c34de4006b5761d0d6564307 /models/clip
parentFix (diff)
downloadtextual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.gz
textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.bz2
textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.zip
Update
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6fda33c..dc4708a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
37 37
38 38
39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): 40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
41 super().__init__(config) 41 super().__init__(config)
42 42
43 self.token_embedding = embeddings.token_embedding 43 self.token_embedding = embeddings.token_embedding
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
46 46
47 self.token_override_embedding = PseudoSparseEmbedding( 47 self.token_override_embedding = PseudoSparseEmbedding(
48 self.token_embedding.embedding_dim, 48 self.token_embedding.embedding_dim,
49 dropout_p=dropout_p,
49 device=self.token_embedding.weight.device, 50 device=self.token_embedding.weight.device,
50 dtype=self.token_embedding.weight.dtype, 51 dtype=self.token_embedding.weight.dtype,
51 ) 52 )
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
134 return embeddings 135 return embeddings
135 136
136 137
137def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: 138def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
138 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) 139 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
139 text_encoder.text_model.embeddings = text_embeddings 140 text_encoder.text_model.embeddings = text_embeddings
140 return text_embeddings 141 return text_embeddings