# models/clip/embeddings.py
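
"""Managed CLIP text embeddings with support for extra, trainable tokens.

Newly added tokens live in a temporary embedding table so they can be trained,
saved, and loaded independently of the original token embedding, and merged
into it once they are made permanent.
"""
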
from typing import Union, Optional
from pathlib import Path

import torch
import torch.nn as nn

from safetensors import safe_open
from safetensors.torch import save_file

from transformers import CLIPTextModel
from transformers.models.clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
    """Return a copy of `old_embedding` grown by `n` zero-initialized rows."""
    old_num_embeddings, old_embedding_dim = old_embedding.weight.size()

    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
    new_embedding.weight.data.zero_()
    new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data

    return new_embedding


class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
        super().__init__(config)

        self.token_embedding = embeddings.token_embedding
        self.position_embedding = embeddings.position_embedding

        # Newly added tokens are trained in this temporary table and merged
        # into `token_embedding` by `make_permanent()`.
        self.temp_token_embedding = nn.Embedding(
            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
        self.temp_token_embedding.weight.data.zero_()
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        if initializer is not None:
            if isinstance(initializer, int):
                initializer = [initializer]

            if isinstance(initializer, list):
                # Tile the initializer ids so there is one per new token.
                initializer = (initializer * len(token_ids))[:len(token_ids)]

                with torch.no_grad():
                    initializer = self.get_embed(initializer)

        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))

        token_ids = torch.tensor(token_ids, dtype=torch.long)

        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])

        if initializer is not None:
            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
                device=self.temp_token_embedding.weight.device,
                dtype=self.temp_token_embedding.weight.dtype,
            )
        else:
            # Indexing with a tensor returns a copy, so calling .zero_() on it
            # would be a no-op; assign zeros directly instead.
            self.temp_token_embedding.weight.data[token_ids] = 0

    def load_embed(self, input_ids: list[int], filename: Path):
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def make_permanent(self):
        # Copy the trained temporary rows into the main token embedding and
        # stop overriding them in get_embed().
        self.token_embedding.weight.data[self.temp_token_ids] = \
            self.temp_token_embedding.weight.data[self.temp_token_ids]
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        if isinstance(input_ids, list):
            input_ids = torch.tensor(
                input_ids, device=self.token_embedding.weight.device, dtype=torch.long)

        # For tokens that are still temporary, substitute the trainable
        # embedding for the frozen one.
        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))

        embeds = self.token_embedding(input_ids)
        embeds[mask] = self.temp_token_embedding(input_ids)[mask]

        return embeds

        return embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.get_embed(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
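

# A minimal usage sketch, not part of the original module: it assumes a stock
# Hugging Face CLIP checkpoint and a hypothetical placeholder token
# "<new-concept>"; real training code may wire things up differently.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    pretrained = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(pretrained)
    text_encoder = CLIPTextModel.from_pretrained(pretrained)

    # Swap the stock embeddings for the managed version.
    embeddings = patch_managed_embeddings(text_encoder)

    # Register a placeholder token and initialize it from an existing word.
    tokenizer.add_tokens(["<new-concept>"])
    new_ids = tokenizer.convert_tokens_to_ids(["<new-concept>"])
    init_ids = tokenizer.encode("cat", add_special_tokens=False)
    embeddings.add_embed(new_ids, init_ids)

    # After training the temporary rows, persist them and merge them into the
    # main token embedding.
    with torch.no_grad():
        embeddings.save_embed(new_ids, Path("new-concept.safetensors"))
    embeddings.make_permanent()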