from typing import Optional
import math
import torch
import torch.nn as nn
class LoraLayer:
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout_p = lora_dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        self.merged = False
        self.merge_weights = merge_weights

class LoraEmbedding(nn.Embedding, LoraLayer):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
        )
        # trainable_ids[i] == -1 means token i has no LoRA row; otherwise it indexes into lora_A.
        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
        # Per-token rank-r vectors (lora_A) are projected up to embedding_dim by the shared lora_B.
        self.lora_A = nn.ParameterList()
        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
        self.scaling = self.lora_alpha / self.r  # standard LoRA scaling; requires r > 0
        self.weight.requires_grad = False        # the base embedding matrix stays frozen
        self.reset_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        # Build a copy of this layer with a resized vocabulary, carrying over weights and adapters.
        n = min(self.num_embeddings, new_num_embeddings)
        new_emb = LoraEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.r,
            self.lora_alpha,
            self.lora_dropout_p,
            merge_weights=self.merge_weights,
            device=self.weight.device,
            dtype=self.weight.dtype
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        for param in self.lora_A:
            new_emb.lora_A.append(param)
        new_emb.lora_B.weight.data.copy_(self.lora_B.weight.data)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        # Allocate a trainable LoRA row for every id that does not already have one.
        trainable_ids = self.trainable_ids[input_ids]
        new_ids = torch.unique(input_ids[trainable_ids == -1])
        if new_ids.shape[0] == 0:
            return
        n1 = len(self.lora_A)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2, device=self.trainable_ids.device)
        for _ in range(new_ids.shape[0]):
            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
        # Initialise only the newly added rows and write the values back into the parameters
        # (kaiming_uniform_ on a stacked copy would otherwise be discarded); rows that were
        # already trained are left untouched.
        new_rows = torch.empty(n2 - n1, self.r, device=self.weight.device, dtype=self.weight.dtype)
        nn.init.kaiming_uniform_(new_rows, a=math.sqrt(5))
        for i in range(n1, n2):
            self.lora_A[i].data.copy_(new_rows[i - n1])

    def get_weights(self, input_ids: torch.Tensor):
        # Return the LoRA delta for each id (zeros when merged or for ids without a LoRA row).
        if len(input_ids.shape) != 1:
            return torch.stack([self.get_weights(batch) for batch in input_ids])
        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
        if not self.merged:
            trainable_ids = self.trainable_ids[input_ids]
            mask = trainable_ids != -1
            elems = [self.lora_A[int(idx)] for idx in trainable_ids[mask]]
            if len(elems) != 0:
                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
                weights[mask] = w.to(dtype=weights.dtype)
        return weights

    def persist(self):
        # Fold the trained LoRA deltas into the frozen embedding matrix and reset the adapters.
        all_ids = torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
        self.weight.data += self.get_weights(all_ids)
        self.trainable_ids[:] = -1
        self.lora_A = nn.ParameterList()
        nn.init.zeros_(self.lora_B.weight)

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        # This is also called by nn.Embedding.__init__ before the LoRA attributes exist, hence the guard.
        if hasattr(self, "lora_A"):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            nn.init.zeros_(self.lora_B.weight)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        self.lora_A.train(mode)
        self.lora_B.train(mode)
        all_ids = torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
        if not mode and self.merge_weights and not self.merged:
            # Entering eval: merge the LoRA deltas into the embedding matrix.
            self.weight.data += self.get_weights(all_ids)
            self.merged = True
        elif mode and self.merge_weights and self.merged:
            # Back to training: clear the flag first so get_weights returns the deltas again,
            # then subtract them to restore the unmerged weights.
            self.merged = False
            self.weight.data -= self.get_weights(all_ids)

    def eval(self):
        return self.train(False)

    def forward(self, input_ids: torch.LongTensor):
        # Frozen embedding lookup plus the per-token LoRA delta (zero once the deltas are merged).
        result = nn.Embedding.forward(self, input_ids)
        return result + self.get_weights(input_ids)
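

if __name__ == "__main__":
    # Minimal usage sketch; the sizes and token ids below are illustrative assumptions, not from
    # the source. Mark the ids of newly added tokens as trainable, run the layer, then persist.
    emb = LoraEmbedding(num_embeddings=100, embedding_dim=16, r=4, lora_alpha=8)
    new_token_ids = torch.tensor([97, 98, 99])   # e.g. ids of tokens appended to the vocabulary
    emb.mark_trainable(new_token_ids)            # allocates trainable rank-r rows for these ids
    # During fine-tuning only the LoRA parameters would be optimised; the base weight is frozen:
    # optimizer = torch.optim.AdamW(list(emb.lora_A.parameters()) + list(emb.lora_B.parameters()), lr=1e-3)
    out = emb(torch.tensor([[1, 97, 99]]))       # frozen lookup + per-token LoRA delta
    print(out.shape)                             # torch.Size([1, 3, 16])
    emb.persist()                                # fold the trained deltas into emb.weight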