from typing import Union, Optional
from pathlib import Path

import torch
import torch.nn as nn
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel
from transformers.models.clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding:
    """Grow (or shrink) an embedding layer to `new_num_embeddings` rows, copying existing rows over."""
    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape

    if old_num_embeddings == new_num_embeddings:
        return old_embedding

    n = min(old_num_embeddings, new_num_embeddings)

    new_embedding = nn.Embedding(
        new_num_embeddings,
        old_embedding_dim,
        device=old_embedding.weight.device,
        dtype=old_embedding.weight.dtype,
    )
    # New rows follow the CLIP initialization scheme; the first `n` rows are then overwritten with the old weights.
    new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
    return new_embedding


class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings with a parallel temporary token embedding table.

    Newly added (in-training) token ids are looked up in `temp_token_embedding`
    until `persist()` copies their rows back into the regular `token_embedding`.
    """

    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
        super().__init__(config)

        # Reuse the weights of the embedding module being wrapped.
        self.token_embedding = embeddings.token_embedding
        self.position_embedding = embeddings.position_embedding
        self.initializer_factor = config.initializer_factor

        # Shadow copy of the token embedding that holds the trainable rows.
        self.temp_token_embedding = nn.Embedding(
            self.token_embedding.num_embeddings,
            self.token_embedding.embedding_dim,
            device=self.token_embedding.weight.device,
            dtype=self.token_embedding.weight.dtype,
        )
        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def resize(self, size: int):
        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)

    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # By default, initialize each new token from its own current embedding.
        if initializer is None:
            initializer = token_ids

        if isinstance(initializer, int):
            initializer = [initializer]

        if isinstance(initializer, list):
            # Cycle the initializer ids to match the number of new tokens,
            # then look up their embeddings as the starting values.
            initializer = (initializer * len(token_ids))[:len(token_ids)]

            with torch.no_grad():
                initializer = self.get_embed(initializer)

        token_ids = torch.tensor(token_ids, dtype=torch.long)

        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
        self.temp_token_embedding.weight.data[token_ids] = initializer.to(
            device=self.temp_token_embedding.weight.device,
            dtype=self.temp_token_embedding.weight.dtype,
        )

    def load_embed(self, input_ids: list[int], filename: Path):
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def persist(self):
        # Copy the trained temporary rows back into the regular token embedding.
        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        if isinstance(input_ids, list):
            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)

        embeds = self.token_embedding(input_ids)

        # Tokens still in training are served from the temporary table instead.
        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
        embeds[mask] = self.temp_token_embedding(input_ids)[mask]

        return embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.get_embed(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
    """Swap the text encoder's embedding module for a managed one and return it."""
    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
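

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way the pieces above can be wired
# together for textual-inversion-style training. The checkpoint name, the
# placeholder token "<my-concept>", the initializer word "painting", and the
# output filename are assumptions made for this example, not part of the
# module above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    pretrained = "openai/clip-vit-large-patch14"  # assumed checkpoint
    tokenizer = CLIPTokenizer.from_pretrained(pretrained)
    text_encoder = CLIPTextModel.from_pretrained(pretrained)

    # Swap in the managed embedding module.
    embeddings = patch_managed_embeddings(text_encoder)

    # Register a placeholder token and grow both embedding tables to fit it.
    tokenizer.add_tokens(["<my-concept>"])
    new_id = tokenizer.convert_tokens_to_ids("<my-concept>")
    embeddings.resize(len(tokenizer))

    # Start the new token from the embedding of an existing word.
    init_id = tokenizer.encode("painting", add_special_tokens=False)[0]
    embeddings.add_embed(new_id, init_id)

    # ... a training loop would optimize embeddings.temp_token_embedding ...

    # Bake the learned rows into the regular token embedding and save them.
    embeddings.persist()
    with torch.no_grad():
        embeddings.save_embed([new_id], Path("my-concept.safetensors"))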