Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--  models/clip/embeddings.py  29
1 file changed, 17 insertions, 12 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7c7f2ac..8c3c6d4 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        embeddings: CLIPTextEmbeddings,
+        alpha: int = 8,
+        dropout: float = 0.0,
+    ):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
-        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(
+            size, self.initializer_factor
+        )
 
     def add_embed(
         self,
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
-            initializer = (initializer * len(token_ids))[:len(token_ids)]
+            initializer = (initializer * len(token_ids))[: len(token_ids)]
 
         with torch.no_grad():
             initializer = self.get_embed(initializer)
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
+            input_ids = torch.tensor(
+                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
+            )
 
         return self.token_embedding(input_ids)
 
 
 def patch_managed_embeddings(
-    text_encoder: CLIPTextModel,
-    alpha: int = 8,
-    dropout: float = 0.0
+    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
         return text_encoder.text_model.embeddings
 
     text_embeddings = ManagedCLIPTextEmbeddings(
-        text_encoder.config,
-        text_encoder.text_model.embeddings,
-        alpha,
-        dropout
+        text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
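
For context, a minimal usage sketch of the patch_managed_embeddings entry point touched by this commit; the checkpoint name below is an assumption for illustration only and is not part of the diff:

    # Sketch only: wrap a CLIP text encoder's embeddings in ManagedCLIPTextEmbeddings.
    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # Assumed checkpoint for illustration; any CLIPTextModel works the same way.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)

    # Calling it again returns the same instance rather than re-wrapping,
    # because the function checks isinstance() on the existing embeddings first.
    assert patch_managed_embeddings(text_encoder) is embeddings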