from typing import Optional

import torch
import torch.nn as nn

class LoraLayer:
    """Mixin holding the hyper-parameters shared by LoRA-adapted layers."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout_p = lora_dropout
        # Use a real dropout module only when a positive rate is requested.
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # `merged` tracks whether the LoRA delta is currently folded into the
        # base weight; `merge_weights` controls whether train()/eval() should
        # fold and unfold it automatically.
        self.merged = False
        self.merge_weights = merge_weights

class LoraEmbedding(nn.Embedding, LoraLayer):
    """Embedding layer with per-token LoRA adapters.

    Each token marked as trainable gets its own rank-r vector in `lora_A`;
    the shared projection `lora_B` maps those vectors to embedding space.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
        )
        # trainable_ids[token_id] is the index of that token's LoRA slot,
        # or -1 if the token has no adapter.
        self.register_buffer(
            'trainable_ids',
            self.weight.new_full((num_embeddings,), -1, dtype=torch.long),
        )
        if r > 0:
            self.lora_A = nn.ParameterList()  # one rank-r vector per trainable token
            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
            self.scaling = self.lora_alpha / self.r
            # Only the LoRA parameters are trained; the base table is frozen.
            self.weight.requires_grad = False
        self.reset_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        """Return a copy with a resized vocabulary that shares this layer's LoRA state."""
        n = min(self.num_embeddings, new_num_embeddings)
        new_emb = LoraEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            r=self.r,
            lora_alpha=self.lora_alpha,
            lora_dropout=self.lora_dropout_p,
            merge_weights=self.merge_weights,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        # Newly added rows are drawn from the usual init or zeroed out.
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        if self.r > 0:
            # Share (not copy) the adapters so training continues seamlessly;
            # guarding on r avoids an AttributeError when LoRA is disabled.
            new_emb.lora_A = self.lora_A
            new_emb.lora_B = self.lora_B
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Allocate a LoRA slot for every id in `input_ids` that lacks one."""
        if self.r == 0:
            return
        # Deduplicate so a token repeated in the batch gets exactly one slot.
        new_ids = input_ids[self.trainable_ids[input_ids] == -1].unique()
        if new_ids.numel() == 0:
            return
        n1 = len(self.lora_A)
        n2 = n1 + new_ids.numel()
        self.trainable_ids[new_ids] = torch.arange(n1, n2, device=self.trainable_ids.device)
        for _ in range(new_ids.numel()):
            # Wrap in nn.Parameter so the new vector is registered and optimized.
            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))

    def get_weights(self, input_ids: torch.Tensor, apply_dropout: bool = True):
        """Return stacked LoRA deltas for adapted positions plus their mask.

        `apply_dropout=False` gives the deterministic deltas used for merging.
        """
        trainable_ids = self.trainable_ids[input_ids]
        mask = trainable_ids != -1
        trainable_ids = trainable_ids[mask]
        elems = [self.lora_A[slot] for slot in trainable_ids]
        if len(elems) != 0:
            stacked = torch.stack(elems)
            if apply_dropout:
                stacked = self.lora_dropout(stacked)
            weights = self.lora_B(stacked) * self.scaling
        else:
            # No adapted positions: a zero vector broadcasts against result[mask].
            weights = self.weight.new_zeros(self.embedding_dim)
        return weights, mask

    def persist(self):
        """Fold the LoRA deltas into the base weight and drop the adapters."""
        if self.r > 0:
            if not self.merged:
                with torch.no_grad():
                    all_ids = torch.arange(self.num_embeddings, device=self.weight.device)
                    weights, mask = self.get_weights(all_ids, apply_dropout=False)
                    # `self.weight[mask].data += ...` would update a copy; index .data.
                    self.weight.data[mask] += weights
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            self.merged = False

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            # Matches loralib's Embedding: A starts at zero so the delta starts
            # at zero, while a non-zero B lets gradients reach A. Zeroing both
            # A and B would leave the adapters with no gradient signal at all.
            nn.init.normal_(self.lora_B.weight)

    def train(self, mode: bool = True):
        # nn.Module.eval() calls train(False), so this single override also
        # covers evaluation; a separate eval() override would merge and
        # unmerge twice. Deltas are computed without dropout so that merging
        # and unmerging are exact inverses.
        nn.Embedding.train(self, mode)
        if self.r > 0 and self.merge_weights:
            if mode and self.merged:
                # Back to training: subtract the previously merged deltas.
                with torch.no_grad():
                    all_ids = torch.arange(self.num_embeddings, device=self.weight.device)
                    weights, mask = self.get_weights(all_ids, apply_dropout=False)
                    self.weight.data[mask] -= weights
                self.merged = False
            elif not mode and not self.merged:
                # Evaluation: fold the deltas in so forward() is a plain lookup.
                with torch.no_grad():
                    all_ids = torch.arange(self.num_embeddings, device=self.weight.device)
                    weights, mask = self.get_weights(all_ids, apply_dropout=False)
                    self.weight.data[mask] += weights
                self.merged = True
        return self

    def forward(self, input_ids: torch.LongTensor):
        result = nn.Embedding.forward(self, input_ids)
        if self.r > 0 and not self.merged:
            # Add low-rank deltas only at positions whose token has an adapter.
            weights, mask = self.get_weights(input_ids)
            result[mask] += weights
        return result
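
# Usage sketch (illustrative, not part of the original module): sizes, ids,
# and hyper-parameters below are made up to show the intended call sequence.
if __name__ == "__main__":
    emb = LoraEmbedding(num_embeddings=10, embedding_dim=8, r=4, lora_alpha=8)
    # Give token ids 2 and 5 their own rank-r adapters (duplicates are deduped).
    emb.mark_trainable(torch.tensor([2, 5, 5]))
    emb.train()
    out = emb(torch.tensor([[1, 2, 5]]))  # (1, 3, 8); rows for 2 and 5 get deltas
    out.sum().backward()                  # gradients reach only lora_A / lora_B
    # Fold the learned deltas into the base table and drop the adapters.
    emb.persist()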