1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoraLayer:
    """Mixin that stores the LoRA hyper-parameters shared by LoRA-enabled layers.

    Attributes set on the instance:
        r: Rank of the low-rank update, stored as given.
        lora_alpha: Scaling numerator, stored as given.
        lora_dropout_p: The raw dropout probability that was passed in.
        lora_dropout: ``nn.Dropout`` when the probability is positive,
            otherwise ``nn.Identity``.
        merged: Starts as ``False``.
        merge_weights: Stored as given.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        # Decide once whether a real dropout module is required.
        use_dropout = lora_dropout > 0.

        self.merge_weights = merge_weights
        self.merged = False

        self.lora_dropout_p = lora_dropout
        self.lora_dropout = nn.Dropout(p=lora_dropout) if use_dropout else nn.Identity()

        self.lora_alpha = lora_alpha
        self.r = r
class LoraEmbedding(nn.Embedding, LoraLayer):
    """Embedding layer whose LoRA update is applied only to selected token ids.

    The ``trainable_ids`` buffer has one entry per embedding row. An entry of
    ``-1`` means the row has no LoRA slot; any other value is the column of
    ``lora_A`` assigned to that row. ``lora_A`` has shape ``(r, n_slots)`` and
    ``lora_B`` has shape ``(embedding_dim, r)``, so the update added to a row
    is ``(lora_B @ lora_A[:, slot]) * scaling``.

    When ``r > 0`` the base ``weight`` is frozen (``requires_grad = False``).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        """Build the base embedding and, when ``r > 0``, the LoRA parameters.

        Args:
            num_embeddings: Number of rows in the embedding table.
            embedding_dim: Size of each embedding vector.
            r: LoRA rank; ``0`` disables the LoRA parameters entirely.
            lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
            lora_dropout: Dropout probability stored by ``LoraLayer``.
            merge_weights: Whether switching to eval mode folds the LoRA
                update into ``weight`` (and train mode removes it again).
            **kwargs: Forwarded unchanged to ``nn.Embedding.__init__``.
        """
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
        )

        # -1 marks "no LoRA slot assigned to this row".
        self.register_buffer(
            'trainable_ids',
            torch.full((num_embeddings,), -1, device=self.weight.device, dtype=torch.long),
        )

        if r > 0:
            # Starts with zero slots; mark_trainable() grows the second dim.
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            self.weight.requires_grad = False
            # nn.Embedding.__init__ has already initialised ``weight`` (or kept
            # a caller-supplied ``_weight``); only the LoRA matrices still need
            # initialising, so the base weight is not reset a second time.
            self._reset_lora_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        """Return a new ``LoraEmbedding`` with ``new_num_embeddings`` rows.

        The first ``min(old, new)`` rows of ``weight`` and ``trainable_ids``
        are copied over. Extra rows are zero, or drawn from a normal
        distribution with std ``initializer_factor * 0.02`` when
        ``initializer_factor`` is given. The LoRA parameters are shared with
        (not copied from) this instance.

        Args:
            new_num_embeddings: Row count of the returned embedding.
            initializer_factor: Optional multiplier for the init std of rows.

        Returns:
            The resized ``LoraEmbedding``.
        """
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = LoraEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            r=self.r,
            lora_alpha=self.lora_alpha,
            lora_dropout=self.lora_dropout_p,
            merge_weights=self.merge_weights,
            device=self.weight.device,
            dtype=self.weight.dtype
        )

        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)

        new_emb.weight.data[:n, :] = self.weight.data[:n, :]

        # The LoRA parameters only exist when r > 0.
        if self.r > 0:
            new_emb.lora_A = self.lora_A
            new_emb.lora_B = self.lora_B

        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        # The copied rows may already contain the merged LoRA update; carry
        # the flag (and the mode) over so the update is not applied twice.
        new_emb.merged = self.merged
        nn.Embedding.train(new_emb, self.training)

        return new_emb

    def mark_trainable(self, input_ids):
        """Assign a LoRA slot to every id in ``input_ids`` that has none yet.

        ``lora_A`` is replaced by a larger ``nn.Parameter``; existing columns
        are copied over and the new columns are zero. Because a new parameter
        object is created, anything holding a reference to the old ``lora_A``
        (for example an optimizer) no longer refers to the live parameter.

        Does nothing when ``r == 0`` or when every id already has a slot.

        Args:
            input_ids: Token ids (tensor or anything ``torch.as_tensor``
                accepts) to make trainable.
        """
        if self.r <= 0:
            return

        slot_table = self.trainable_ids
        input_ids = torch.as_tensor(
            input_ids, dtype=torch.long, device=slot_table.device
        ).reshape(-1)

        # Pick the *token ids* without a slot, each one only once.
        new_ids = torch.unique(input_ids[slot_table[input_ids] == -1])

        if new_ids.shape[0] == 0:
            return

        n1 = self.lora_A.shape[1]
        n2 = n1 + new_ids.shape[0]

        slot_table[new_ids] = torch.arange(
            n1, n2, device=slot_table.device, dtype=slot_table.dtype
        )

        lora_A = nn.Parameter(
            self.weight.new_zeros((self.r, n2)),
            requires_grad=self.lora_A.requires_grad,
        )
        # Keep what was already learned for the existing slots.
        lora_A.data[:, :n1] = self.lora_A.data
        self.lora_A = lora_A

    def reset_parameters(self):
        """Re-initialise the base weight and, if present, the LoRA matrices."""
        nn.Embedding.reset_parameters(self)
        self._reset_lora_parameters()

    def _reset_lora_parameters(self):
        """Zero ``lora_A`` and draw ``lora_B`` from a standard normal."""
        # nn.Embedding.__init__ calls reset_parameters() before the LoRA
        # parameters exist, hence the attribute check.
        if hasattr(self, 'lora_A'):
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def _merge_delta(self, sign: float):
        """Add ``sign`` times the LoRA update to the rows that own a slot."""
        mask = self.trainable_ids != -1
        slots = self.trainable_ids[mask]
        with torch.no_grad():
            # (embedding_dim, n) -> (n, embedding_dim), one row per masked id.
            delta = (self.lora_B @ self.lora_A[:, slots]).T * self.scaling
            # Index assignment on ``weight.data`` writes into the parameter's
            # storage; indexing first and then taking ``.data`` would only
            # modify a temporary copy.
            self.weight.data[mask] += sign * delta

    def train(self, mode: bool = True):
        """Set train/eval mode and keep the merged state in sync with it.

        With ``merge_weights`` enabled, entering eval mode (``mode=False``)
        folds the LoRA update into ``weight`` and entering train mode removes
        it again. Handling both directions here means a parent module's
        ``eval()`` (which calls ``train(False)`` on its children) also merges.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``.
        """
        nn.Embedding.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                if self.r > 0:
                    self._merge_delta(-1.0)
                self.merged = False
            elif not mode and not self.merged:
                if self.r > 0:
                    self._merge_delta(1.0)
                self.merged = True
        return self

    def eval(self):
        """Switch to evaluation mode; equivalent to ``train(False)``."""
        return self.train(False)

    def forward(self, input_ids: torch.Tensor):
        """Look up embeddings and add the LoRA update for ids that own a slot.

        The LoRA term is skipped when ``r == 0`` or when the update has
        already been merged into ``weight``.

        Args:
            input_ids: Tensor of token ids.

        Returns:
            Tensor of embeddings with a trailing ``embedding_dim`` axis.
        """
        result = nn.Embedding.forward(self, input_ids)
        if self.r > 0 and not self.merged:
            slots = self.trainable_ids[input_ids]
            mask = slots != -1
            after_A = F.embedding(
                slots[mask],
                self.lora_A.T,
                # ``self.padding_idx`` is a token id, not a LoRA slot index,
                # so it must not be applied to this slot-indexed table.
                padding_idx=None,
                max_norm=self.max_norm,
                norm_type=self.norm_type,
                scale_grad_by_freq=self.scale_grad_by_freq,
                sparse=self.sparse,
            )
            result[mask] += (after_A @ self.lora_B.T) * self.scaling
        return result
|