Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--  models/clip/embeddings.py | 41
1 file changed, 9 insertions(+), 32 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
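
Usage sketch (not part of this commit): how the updated patch_managed_embeddings signature might be called after this change. The checkpoint name and argument values are illustrative assumptions, not taken from the diff.

# Hypothetical usage example; checkpoint name and values are placeholders.
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Replaces text_encoder.text_model.embeddings with ManagedCLIPTextEmbeddings,
# whose token_embedding is a SparseEmbedding sharing the original weights.
embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)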