path: root/training/ti.py
blob: 2efd2f2abe7471d7cee62b42ecb1cb6b0f7a76d0
from typing import Optional

import torch
import torch.nn as nn

from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
    # Share the pretrained tables with the new module and keep them frozen;
    # only `trainable_embedding` receives gradients.
    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
    text_embeddings.token_embedding.weight.requires_grad = False
    text_embeddings.position_embedding.weight.requires_grad = False
    # Re-seed the trainable rows, which __init__ copied while the weights
    # were still randomly initialized.
    with torch.no_grad():
        text_embeddings.trainable_embedding.weight.copy_(
            text_embeddings.token_embedding.weight[text_embeddings.train_indices]
        )
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
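

# Usage sketch (illustrative): the placeholder token and the surrounding
# tokenizer calls below are assumptions about the caller's pipeline, not
# something this module prescribes.
#
#     from transformers import CLIPTokenizer
#
#     tokenizer = CLIPTokenizer.from_pretrained(name, subfolder="tokenizer")
#     text_encoder = CLIPTextModel.from_pretrained(name, subfolder="text_encoder")
#
#     tokenizer.add_tokens(["<concept>"])
#     text_encoder.resize_token_embeddings(len(tokenizer))
#     new_ids = tokenizer.convert_tokens_to_ids(["<concept>"])
#
#     embeddings = patch_trainable_embeddings(text_encoder, new_ids)
#     optimizer = torch.optim.AdamW(embeddings.trainable_embedding.parameters(), lr=1e-4)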


class TrainableEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings in which only the rows for `new_ids` are trained.

    The full token and position tables stay frozen; lookups of the new ids
    are redirected to a small trainable copy of the corresponding rows.
    """

    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        super().__init__(config)

        # Freeze the shared tables; `patch_trainable_embeddings` swaps the
        # pretrained weights in and re-applies this freeze.
        self.token_embedding.weight.requires_grad = False
        self.position_embedding.weight.requires_grad = False

        # Row i of `trainable_embedding` holds the embedding for new_ids[i];
        # `id_mapping` and `train_indices` share that ordering so the lookup
        # stays consistent for any order of `new_ids`.
        self.id_mapping = {token_id: i for i, token_id in enumerate(new_ids)}
        self.register_buffer(
            "train_indices", torch.tensor(new_ids, dtype=torch.long), persistent=False
        )

        self.trainable_embedding = nn.Embedding.from_pretrained(
            self.token_embedding.weight[self.train_indices], freeze=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            # Mark positions holding one of the new, trainable token ids; the
            # trailing singleton dim broadcasts over the embedding dimension.
            mask = torch.isin(
                input_ids,
                self.train_indices.to(input_ids.device)
            ).unsqueeze(-1)

            # Remap new token ids to their row in `trainable_embedding`;
            # ordinary tokens get a dummy row 0 and are masked out below.
            # The `int()` cast is required: a 0-dim tensor never hashes equal
            # to a plain int key, so lookups would silently miss without it.
            trainable_input_ids = torch.tensor([
                [self.id_mapping.get(int(token_id), 0) for token_id in batch]
                for batch in input_ids
            ], device=input_ids.device)

            # Trainable rows where the mask is set, frozen rows elsewhere.
            inputs_embeds = torch.where(
                mask,
                self.trainable_embedding(trainable_input_ids),
                self.token_embedding(input_ids)
            )

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
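

if __name__ == "__main__":
    # Minimal smoke test; an illustrative sketch, not part of the training
    # pipeline. It builds a tiny randomly initialized CLIPTextModel (the
    # config values below are arbitrary assumptions) so nothing is
    # downloaded, patches two "new" token ids, and checks that gradients
    # reach only the trainable rows.
    config = CLIPTextConfig(
        vocab_size=100,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=16,
    )
    text_encoder = CLIPTextModel(config)

    embeddings = patch_trainable_embeddings(text_encoder, new_ids=[98, 99])

    input_ids = torch.tensor([[0, 98, 5, 99, 1]])
    text_encoder(input_ids=input_ids).last_hidden_state.sum().backward()

    assert embeddings.trainable_embedding.weight.grad is not None
    assert embeddings.token_embedding.weight.grad is None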