1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
from typing import Optional
import torch
import torch.nn as nn
from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    """Replace the text encoder's embedding layer with a ``TrainableEmbeddings``.

    The encoder's existing token and position embedding modules are shared with
    (not copied into) the new layer and frozen. The rows for ``new_ids`` are
    trained through a separate ``trainable_embedding`` table, seeded from the
    encoder's current rows for those ids.

    Args:
        text_encoder: CLIP text model whose ``text_model.embeddings`` is
            replaced in place.
        new_ids: Token ids whose embeddings should be trainable.

    Returns:
        The newly installed ``TrainableEmbeddings`` module.
    """
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
    old_embeddings = text_encoder.text_model.embeddings

    # Reuse the encoder's own embedding modules and freeze them.
    text_embeddings.token_embedding = old_embeddings.token_embedding
    text_embeddings.token_embedding.weight.requires_grad = False
    text_embeddings.position_embedding = old_embeddings.position_embedding
    text_embeddings.position_embedding.weight.requires_grad = False

    # TrainableEmbeddings.__init__ seeded its trainable rows from the token
    # embedding module it constructed itself (not the encoder's). Re-seed them
    # from the encoder's rows now that those are attached. ``copy_`` keeps the
    # trainable parameter's own dtype.
    token_weight = text_embeddings.token_embedding.weight
    with torch.no_grad():
        train_indices = torch.as_tensor(new_ids, dtype=torch.long, device=token_weight.device)
        text_embeddings.trainable_embedding.weight.copy_(token_weight[train_indices])

    # The new module's own buffers/parameters were created on the default
    # device; place them next to the shared embedding weights.
    text_embeddings.to(token_weight.device)

    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
class TrainableEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings in which a chosen set of token ids is trainable.

    Ids listed in ``new_ids`` are embedded by ``trainable_embedding`` (one row
    per entry, in ``new_ids`` order); every other id goes through the inherited
    ``token_embedding``. Position embeddings are added as in the parent class.
    """

    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        super().__init__(config)
        # token id -> row index in ``trainable_embedding``.
        self.id_mapping = {token_id: row for row, token_id in enumerate(new_ids)}
        # Non-persistent buffer: follows ``.to(device)`` with the module but is
        # not written to the state dict. Explicit dtype so an empty list does
        # not produce a float tensor.
        self.register_buffer(
            "train_indices", torch.tensor(new_ids, dtype=torch.long), persistent=False
        )
        self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim)
        # Seeded from whatever ``token_embedding`` holds at this point (the
        # module just built by the parent constructor). Code that swaps in a
        # different ``token_embedding`` afterwards must re-copy these rows.
        self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices]
        self.trainable_embedding.weight.requires_grad = True

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Embed tokens and add position embeddings.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is None.
            position_ids: Position ids; defaults to the first ``seq_length``
                entries of the ``position_ids`` buffer.
            inputs_embeds: Precomputed token embeddings, bypassing the lookup.

        Returns:
            ``inputs_embeds + position_embedding(position_ids)``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self._embed_tokens(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings

    def _embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up ``input_ids``, routing ids in ``train_indices`` to the trainable table."""
        if self.train_indices.numel() == 0:
            # Nothing trainable: argmax over an empty dim would fail below.
            return self.token_embedding(input_ids)

        train_indices = self.train_indices.to(input_ids.device)
        # (..., seq, n_new): True where a token equals the corresponding new id.
        # Vectorised replacement for a per-token dict lookup (a 0-d tensor
        # hashes by identity, so it never matches the int keys of a dict).
        matches = input_ids.unsqueeze(-1) == train_indices
        # (..., seq, 1): broadcasts over the embedding dimension.
        mask = matches.any(dim=-1, keepdim=True)
        # Row of the matching new id; 0 for ordinary tokens, whose trainable
        # lookup is discarded by ``mask`` in the ``where`` below.
        trainable_ids = matches.long().argmax(dim=-1)
        return torch.where(
            mask,
            self.trainable_embedding(trainable_ids),
            self.token_embedding(input_ids),
        )
|