1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
|
from typing import Union, Optional
from pathlib import Path
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel
from transformers.models.clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
from models.sparse import SparseEmbedding
class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings whose token table is a trainable SparseEmbedding.

    Wraps an existing ``CLIPTextEmbeddings``: the position embedding and the
    token-embedding weights are shared with the wrapped module, while the
    dense token table is replaced by a ``SparseEmbedding`` so individual
    token vectors can be marked trainable, resized, and saved/loaded.
    """

    def __init__(
        self,
        config: CLIPTextConfig,
        embeddings: CLIPTextEmbeddings,
        alpha: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__(config)

        # Reuse the wrapped module's position embedding as-is.
        self.position_embedding = embeddings.position_embedding
        self.initializer_factor = config.initializer_factor

        # super().__init__ created a dense token table; replace it with a
        # sparse one of identical dimensions, then share the original weights.
        dense = self.token_embedding
        self.token_embedding = SparseEmbedding(
            dense.num_embeddings,
            dense.embedding_dim,
            alpha,
            dropout,
        )
        self.token_embedding.weight = embeddings.token_embedding.weight

    def resize(self, size: int):
        """Replace the token table with one resized to ``size`` entries."""
        self.token_embedding = self.token_embedding.new_resized(
            size, self.initializer_factor
        )

    def add_embed(
        self,
        token_ids: Union[int, list[int]],
        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
        initializer_noise: float = 0.0,
    ):
        """Initialize the given token ids and mark them trainable.

        ``initializer`` may be a token id, a list of token ids (cycled to
        match ``token_ids``), or a pre-computed embedding tensor; it
        defaults to ``token_ids`` themselves.  ``initializer_noise`` scales
        Gaussian noise added to the seed embeddings.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        seed = token_ids if initializer is None else initializer
        if isinstance(seed, int):
            seed = [seed]

        if isinstance(seed, list):
            # Cycle the seed ids to cover every target token, then look up
            # their current embeddings (no grad: this is pure initialization).
            seed = (seed * len(token_ids))[: len(token_ids)]
            with torch.no_grad():
                seed = self.get_embed(seed)

        weight = self.token_embedding.weight
        seed = seed.to(device=weight.device, dtype=weight.dtype)

        if initializer_noise != 0:
            seed += torch.randn_like(seed) * initializer_noise

        ids = torch.tensor(token_ids, dtype=torch.long)
        self.token_embedding.mark_trainable(ids)
        self.token_embedding.weight.data[ids] = seed

    def load_embed(self, input_ids: list[int], filename: Path):
        """Load the tensor stored under key "embed" and assign it to ``input_ids``."""
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        """Save the embeddings of ``input_ids`` to a safetensors file under key "embed"."""
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def persist(self, clear=False):
        # Delegates to SparseEmbedding.persist; exact semantics (presumably
        # folding trained deltas into the base weights) live there — verify
        # against models.sparse.
        self.token_embedding.persist(clear)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        """Look up embeddings for ``input_ids`` (a list of ids or a LongTensor)."""
        if isinstance(input_ids, list):
            input_ids = torch.tensor(
                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
            )
        return self.token_embedding(input_ids)
def patch_managed_embeddings(
    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
) -> ManagedCLIPTextEmbeddings:
    """Swap ``text_encoder``'s embeddings for a managed version and return it.

    Idempotent: if the encoder has already been patched, the existing
    ManagedCLIPTextEmbeddings instance is returned unchanged.
    """
    current = text_encoder.text_model.embeddings
    if isinstance(current, ManagedCLIPTextEmbeddings):
        return current

    managed = ManagedCLIPTextEmbeddings(
        text_encoder.config, current, alpha, dropout
    )
    text_encoder.text_model.embeddings = managed
    return managed
|