Diffstat (limited to 'models/clip')
-rw-r--r--  models/clip/embeddings.py  41
1 file changed, 9 insertions, 32 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
    )
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
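For context, a minimal usage sketch against the patched signature. Only the patch_managed_embeddings(text_encoder, alpha, dropout) interface comes from the diff above; the checkpoint id and hyperparameter values are illustrative assumptions.

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Load a CLIP text encoder; the checkpoint id here is just an example.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Swap the stock token embeddings for the managed, sparse-backed version.
embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.1)

# The patch mutates the encoder in place and returns the new module.
assert text_encoder.text_model.embeddings is embeddings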
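The diff does not include models/sparse.py itself, so only the constructor shape SparseEmbedding(num_embeddings, embedding_dim, alpha, dropout) and the presence of a .weight attribute can be read off the call sites above. A hypothetical stand-in with that interface, for illustration only and not this repository's implementation:

import torch
import torch.nn as nn

class SparseEmbeddingSketch(nn.Embedding):
    """Stand-in mirroring the interface implied by the diff; the real
    models.sparse.SparseEmbedding is not shown there."""

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 alpha: int = 8, dropout: float = 0.0):
        # Subclassing nn.Embedding provides the .weight parameter that
        # ManagedCLIPTextEmbeddings reassigns after construction.
        super().__init__(num_embeddings, embedding_dim)
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # The real class presumably adds a trainable sparse delta scaled by
        # alpha; this sketch only reproduces the base lookup plus dropout.
        return self.dropout(super().forward(input_ids))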