path: root/models/clip/embeddings.py
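"""Managed token embeddings for a CLIP text encoder.

Wraps ``CLIPTextEmbeddings`` so that the token embedding table is backed by a
``SparseEmbedding`` (see ``models.sparse``), allowing individual token vectors
to be added, marked trainable, and saved to or loaded from safetensors files.
"""
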
from typing import Union, Optional
from pathlib import Path

import torch

from safetensors import safe_open
from safetensors.torch import save_file

from transformers import CLIPTextModel
from transformers.models.clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings

from models.sparse import SparseEmbedding


class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings whose token table is a partially trainable ``SparseEmbedding``."""

    def __init__(
        self,
        config: CLIPTextConfig,
        embeddings: CLIPTextEmbeddings,
        alpha: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__(config)

        # Reuse the original position embedding, and replace the token embedding
        # created by the parent constructor with a sparse one that shares the
        # pretrained weights.
        self.position_embedding = embeddings.position_embedding
        self.initializer_factor = config.initializer_factor
        self.token_embedding = SparseEmbedding(
            self.token_embedding.num_embeddings,
            self.token_embedding.embedding_dim,
            alpha,
            dropout,
        )
        self.token_embedding.weight = embeddings.token_embedding.weight

    def resize(self, size: int):
        # Resize the token table to `size` entries; the model's initializer
        # factor is passed along so SparseEmbedding can initialise the new rows.
        self.token_embedding = self.token_embedding.new_resized(
            size, self.initializer_factor
        )

    def add_embed(
        self,
        token_ids: Union[int, list[int]],
        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
        initializer_noise: float = 0.0,
    ):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # By default, initialise each token from its own current embedding.
        if initializer is None:
            initializer = token_ids

        if isinstance(initializer, int):
            initializer = [initializer]

        if isinstance(initializer, list):
            # Tile the initializer ids to match the number of tokens, then look
            # up their current embedding vectors.
            initializer = (initializer * len(token_ids))[: len(token_ids)]

            with torch.no_grad():
                initializer = self.get_embed(initializer)

        initializer = initializer.to(
            device=self.token_embedding.weight.device,
            dtype=self.token_embedding.weight.dtype,
        )

        if initializer_noise != 0:
            initializer += torch.randn_like(initializer) * initializer_noise

        token_ids = torch.tensor(token_ids, dtype=torch.long)

        # Mark the new rows as trainable and overwrite them with the initializer.
        self.token_embedding.mark_trainable(token_ids)
        self.token_embedding.weight.data[token_ids] = initializer

    def load_embed(self, input_ids: list[int], filename: Path):
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def persist(self, clear=False):
        # Delegate to SparseEmbedding.persist, committing the currently trainable
        # rows into the underlying weight (optionally clearing the trainable set).
        self.token_embedding.persist(clear)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        if isinstance(input_ids, list):
            input_ids = torch.tensor(
                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
            )

        return self.token_embedding(input_ids)


def patch_managed_embeddings(
    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
) -> ManagedCLIPTextEmbeddings:
    # Idempotent: if the text encoder has already been patched, return the
    # existing managed embeddings instead of wrapping them again.
    if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
        return text_encoder.text_model.embeddings

    text_embeddings = ManagedCLIPTextEmbeddings(
        text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
    )
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
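

# A minimal usage sketch. Assumptions for illustration only: the checkpoint name
# and the placeholder token ids are hypothetical; the surrounding training code
# decides both in practice.
#
#     text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
#     embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)
#
#     # Make room for two placeholder tokens and initialise them from token id 42.
#     new_ids = [text_encoder.config.vocab_size, text_encoder.config.vocab_size + 1]
#     embeddings.resize(text_encoder.config.vocab_size + 2)
#     embeddings.add_embed(new_ids, initializer=42, initializer_noise=0.0)
#
#     # Save the learned vectors, and load them back into another encoder later.
#     embeddings.save_embed(new_ids, Path("placeholder.safetensors"))
#     embeddings.load_embed(new_ids, Path("placeholder.safetensors"))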