summaryrefslogtreecommitdiffstats
path: root/training/ti.py
blob: a5e407b4698f97eee5c2df91b4bb3987afa034b5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import Optional

import torch
import torch.nn as nn

from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]) -> "TrainableEmbeddings":
    """Replace the text encoder's embedding module with a ``TrainableEmbeddings``.

    The existing token and position embedding modules are reused (not copied)
    and frozen, so only the separate ``trainable_embedding`` table of the new
    module receives gradients.

    Args:
        text_encoder: CLIP text model whose ``text_model.embeddings`` is
            replaced in place.
        new_ids: Token ids whose embeddings should be looked up in the
            trainable table instead of the frozen one.

    Returns:
        The newly installed ``TrainableEmbeddings`` module.
    """
    original_embeddings = text_encoder.text_model.embeddings
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)

    # Share the encoder's existing embedding modules and freeze them.
    text_embeddings.token_embedding = original_embeddings.token_embedding
    text_embeddings.token_embedding.weight.requires_grad = False

    text_embeddings.position_embedding = original_embeddings.position_embedding
    text_embeddings.position_embedding.weight.requires_grad = False

    # The constructor cloned the weights of its own freshly initialised token
    # embedding; re-sync the trainable table with the encoder's actual weights
    # so training starts from them. Replacing ``.data`` keeps the parameter
    # object (and its ``requires_grad=True`` flag) intact.
    text_embeddings.trainable_embedding.weight.data = text_embeddings.token_embedding.weight.data.clone()

    text_encoder.text_model.embeddings = text_embeddings

    return text_embeddings


class TrainableEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings with a separate trainable table for selected ids.

    Token ids listed in ``new_ids`` are embedded via ``trainable_embedding``;
    all other ids are embedded via ``token_embedding``.

    NOTE: at construction time ``trainable_embedding`` is a clone of this
    module's own ``token_embedding`` weights as they exist right after
    ``super().__init__``. Code that assigns a different ``token_embedding``
    afterwards must re-sync ``trainable_embedding`` itself if it should start
    from those weights.
    """

    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        """Build the embedding module.

        Args:
            config: CLIP text configuration forwarded to the parent class.
            new_ids: Token ids that are routed to the trainable table.
        """
        super().__init__(config)

        # Explicit integer dtype so an empty ``new_ids`` does not produce a
        # float tensor.
        self.train_indices = torch.tensor(new_ids, dtype=torch.long)

        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
        self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
        self.trainable_embedding.weight.requires_grad = True

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Return token embeddings plus position embeddings.

        Args:
            input_ids: Token ids; required when ``inputs_embeds`` is not given.
            position_ids: Optional position ids. When omitted, the first
                ``seq_length`` entries of ``self.position_ids`` are used
                (presumably the buffer registered by the parent class).
            inputs_embeds: Optional precomputed token embeddings. When given,
                no token lookup is performed and ``input_ids`` is only used to
                determine the sequence length if it is also provided.

        Returns:
            The sum of the token embeddings and the position embeddings.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            # The device is only needed (and only available) on this path;
            # reading it unconditionally broke calls that pass inputs_embeds
            # without input_ids.
            mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))
            inputs_embeds = self.token_embedding(input_ids)
            # Overwrite the positions of trainable ids with rows from the
            # trainable table.
            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings