diff options
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r-- | models/clip/embeddings.py | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 7c7f2ac..8c3c6d4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding | |||
14 | 14 | ||
15 | 15 | ||
16 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 16 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
17 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): | 17 | def __init__( |
18 | self, | ||
19 | config: CLIPTextConfig, | ||
20 | embeddings: CLIPTextEmbeddings, | ||
21 | alpha: int = 8, | ||
22 | dropout: float = 0.0, | ||
23 | ): | ||
18 | super().__init__(config) | 24 | super().__init__(config) |
19 | 25 | ||
20 | self.position_embedding = embeddings.position_embedding | 26 | self.position_embedding = embeddings.position_embedding |
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
28 | self.token_embedding.weight = embeddings.token_embedding.weight | 34 | self.token_embedding.weight = embeddings.token_embedding.weight |
29 | 35 | ||
30 | def resize(self, size: int): | 36 | def resize(self, size: int): |
31 | self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) | 37 | self.token_embedding = self.token_embedding.new_resized( |
38 | size, self.initializer_factor | ||
39 | ) | ||
32 | 40 | ||
33 | def add_embed( | 41 | def add_embed( |
34 | self, | 42 | self, |
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
46 | initializer = [initializer] | 54 | initializer = [initializer] |
47 | 55 | ||
48 | if isinstance(initializer, list): | 56 | if isinstance(initializer, list): |
49 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 57 | initializer = (initializer * len(token_ids))[: len(token_ids)] |
50 | 58 | ||
51 | with torch.no_grad(): | 59 | with torch.no_grad(): |
52 | initializer = self.get_embed(initializer) | 60 | initializer = self.get_embed(initializer) |
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
76 | 84 | ||
77 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
78 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |
79 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 87 | input_ids = torch.tensor( |
88 | input_ids, device=self.token_embedding.weight.device, dtype=torch.long | ||
89 | ) | ||
80 | 90 | ||
81 | return self.token_embedding(input_ids) | 91 | return self.token_embedding(input_ids) |
82 | 92 | ||
83 | 93 | ||
84 | def patch_managed_embeddings( | 94 | def patch_managed_embeddings( |
85 | text_encoder: CLIPTextModel, | 95 | text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0 |
86 | alpha: int = 8, | ||
87 | dropout: float = 0.0 | ||
88 | ) -> ManagedCLIPTextEmbeddings: | 96 | ) -> ManagedCLIPTextEmbeddings: |
89 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): | 97 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): |
90 | return text_encoder.text_model.embeddings | 98 | return text_encoder.text_model.embeddings |
91 | 99 | ||
92 | text_embeddings = ManagedCLIPTextEmbeddings( | 100 | text_embeddings = ManagedCLIPTextEmbeddings( |
93 | text_encoder.config, | 101 | text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout |
94 | text_encoder.text_model.embeddings, | ||
95 | alpha, | ||
96 | dropout | ||
97 | ) | 102 | ) |
98 | text_encoder.text_model.embeddings = text_embeddings | 103 | text_encoder.text_model.embeddings = text_embeddings |
99 | return text_embeddings | 104 | return text_embeddings |