Diffstat (limited to 'models')
 models/clip/embeddings.py |  7 ++++---
 models/sparse.py          | 13 ++++++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6fda33c..dc4708a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_override_embedding = PseudoSparseEmbedding(
             self.token_embedding.embedding_dim,
+            dropout_p=dropout_p,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
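
Note on wiring: patch_managed_embeddings now forwards dropout_p into ManagedCLIPTextEmbeddings, which passes it on to PseudoSparseEmbedding, so dropout is applied only to the overridden token embeddings. A minimal usage sketch; the checkpoint name and the 0.1 value are illustrative and not part of this commit:

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Any CLIPTextModel works; this checkpoint name is just an example.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Swap in the managed embeddings and enable dropout on the learned
# (overridden) token embeddings only.
embeddings = patch_managed_embeddings(text_encoder, dropout_p=0.1)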
diff --git a/models/sparse.py b/models/sparse.py
index d706db5..bcb2897 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -5,22 +5,29 @@ import torch.nn as nn
 
 
 class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
         super().__init__()
 
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
+
+        if dropout_p > 0.0:
+            self.dropout = nn.Dropout(p=dropout_p)
+        else:
+            self.dropout = lambda x: x
+
         self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
-        ids = self.mapping[input_ids.to(self.mapping.device)]
+        input_ids = input_ids.to(self.mapping.device)
+        ids = self.mapping[input_ids]
         mask = ~(ids == -1)
 
         if torch.all(~mask):
             embs = None
         else:
-            embs = torch.stack([self.params[id] for id in ids[mask]])
+            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
 
         return embs, mask
 
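
For reference, forward still returns an (embs, mask) pair; dropout is applied only to the stacked override embeddings, and only while the module is in training mode (nn.Dropout is a no-op in eval mode; with dropout_p == 0.0 the identity lambda is used instead). A minimal sketch, with the mapping and params filled in by hand purely for illustration, since the real code manages them elsewhere:

import torch

from models.sparse import PseudoSparseEmbedding

emb = PseudoSparseEmbedding(embedding_dim=4, dropout_p=0.5)

# Hand-register one override: token id 1 maps to params[0], others to -1 (no override).
emb.params.append(torch.nn.Parameter(torch.randn(4)))
emb.mapping = torch.tensor([-1, 0, -1], dtype=torch.long)

emb.train()  # nn.Dropout is only active in training mode
embs, mask = emb(torch.tensor([0, 1, 2]))
print(mask)        # tensor([False,  True, False])
print(embs.shape)  # torch.Size([1, 4]), with dropout applied to the stacked override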