From ba9fd1a10746d85d2502c8a79ac49db63d346b04 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 9 Apr 2023 11:29:31 +0200
Subject: Add optional dropout to the token override embeddings

---
 models/clip/embeddings.py |  7 ++++---
 models/sparse.py          | 13 ++++++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6fda33c..dc4708a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_override_embedding = PseudoSparseEmbedding(
             self.token_embedding.embedding_dim,
+            dropout_p=dropout_p,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
diff --git a/models/sparse.py b/models/sparse.py
index d706db5..bcb2897 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -5,22 +5,29 @@ import torch.nn as nn
 
 
 class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
         super().__init__()
 
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
+
+        if dropout_p > 0.0:
+            self.dropout = nn.Dropout(p=dropout_p)
+        else:
+            self.dropout = lambda x: x
+
         self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
-        ids = self.mapping[input_ids.to(self.mapping.device)]
+        input_ids = input_ids.to(self.mapping.device)
+        ids = self.mapping[input_ids]
+
         mask = ~(ids == -1)
 
         if torch.all(~mask):
             embs = None
         else:
-            embs = torch.stack([self.params[id] for id in ids[mask]])
+            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
 
         return embs, mask
-- 
cgit v1.2.3-54-g00ecf
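
The new dropout_p threads all the way through patch_managed_embeddings, so enabling it is a one-argument change at the call site. A minimal usage sketch, assuming a standard transformers CLIP checkpoint (the model id below is illustrative, not taken from this patch):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # Illustrative checkpoint; any CLIPTextModel works.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Swap in the managed embeddings; overridden token embeddings now pass
    # through nn.Dropout(p=0.1) whenever the module is in training mode.
    embeddings = patch_managed_embeddings(text_encoder, dropout_p=0.1)

    text_encoder.train()  # dropout active
    text_encoder.eval()   # dropout becomes a pass-through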
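
Inside PseudoSparseEmbedding.forward, mapping holds -1 for every token without an override, so the dropout only ever touches the stacked override vectors; ordinary tokens never reach it. A sketch of the lookup semantics, with the override state set up by hand for illustration (the real mapping and params are populated elsewhere in the repo):

    import torch
    import torch.nn as nn

    from models.sparse import PseudoSparseEmbedding

    emb = PseudoSparseEmbedding(embedding_dim=768, dropout_p=0.1)

    # Pretend a 100-token vocabulary in which only token 42 has an override.
    emb.mapping = torch.full((100,), -1, dtype=torch.long)
    emb.mapping[42] = 0
    emb.params.append(nn.Parameter(torch.zeros(768)))

    embs, mask = emb(torch.tensor([1, 42, 7]))
    # mask == tensor([False, True, False]); embs has shape (1, 768) and, in
    # training mode, passes through nn.Dropout(p=0.1). If no queried token
    # had an override, embs would be None.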
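
One caveat on the dropout_p == 0.0 branch: a bare lambda works for the forward pass, but it is not picklable, so saving the whole module with torch.save (as opposed to its state_dict) would fail, and the no-op is invisible to the module repr and to TorchScript. nn.Identity() is the drop-in alternative if that matters; a hypothetical helper (not part of the patch) showing the shape of that variant:

    import torch.nn as nn

    def make_dropout(dropout_p: float) -> nn.Module:
        # nn.Identity is a picklable no-op module, unlike a bare lambda,
        # so torch.save(module) and scripting keep working.
        return nn.Dropout(p=dropout_p) if dropout_p > 0.0 else nn.Identity()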