Diffstat (limited to 'models/lora.py')
-rw-r--r--  models/lora.py  131
1 file changed, 131 insertions(+), 0 deletions(-)
diff --git a/models/lora.py b/models/lora.py
new file mode 100644
index 0000000..c0f74a6
--- /dev/null
+++ b/models/lora.py
@@ -0,0 +1,131 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


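# Mixin holding the LoRA hyperparameters (rank, scaling, dropout) and the
# merge state shared by LoRA-wrapped modules.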
class LoraLayer():
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout_p = lora_dropout

        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        self.merged = False
        self.merge_weights = merge_weights


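# nn.Embedding with a low-rank adaptation (LoRA) update that is applied only to
# a growing subset of "trainable" token ids, tracked in the trainable_ids buffer.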
class LoraEmbedding(nn.Embedding, LoraLayer):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
        )

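        # trainable_ids maps a vocabulary id to its column in lora_A; -1 means
        # the id has no LoRA update. All ids start out untrained.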
        self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
        self.trainable_ids -= 1

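        # lora_A starts with zero columns and grows one column per id marked
        # trainable; lora_B is (embedding_dim, r). The base weight is frozen and
        # only changes when the low-rank delta is merged in eval().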
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            self.weight.requires_grad = False

        self.reset_parameters()

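    # Build a resized copy of this embedding (e.g. after the vocabulary grows),
    # carrying over the first n rows of the base weight, the LoRA parameters,
    # and the trainable-id mapping.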
    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = LoraEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.r,
            self.lora_alpha,
            self.lora_dropout_p,
            self.merge_weights,
            device=self.weight.device,
            dtype=self.weight.dtype
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        new_emb.lora_A = self.lora_A
        new_emb.lora_B = self.lora_B
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        return new_emb

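    # Register the given token ids as trainable: assign each previously unseen
    # id the next free column in lora_A and grow lora_A accordingly.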
    def mark_trainable(self, input_ids):
        trainable_ids = self.trainable_ids[input_ids]
        new_ids = input_ids[trainable_ids == -1].unique()

        if new_ids.shape[0] == 0:
            return

        # Hand out the next free lora_A columns to the new ids, then grow lora_A.
        n = self.lora_A.shape[1]
        self.trainable_ids[new_ids] = torch.arange(
            n, n + new_ids.shape[0], device=self.trainable_ids.device
        )

        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n + new_ids.shape[0])))
        lora_A.data[:, :n] = self.lora_A.data
        self.lora_A = lora_A

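    # Standard LoRA init for embeddings: the base weight is reset as usual,
    # lora_A starts at zero (so the initial delta is zero) and lora_B is Gaussian.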
    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

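    # Switching to training mode: if the low-rank delta was merged into the base
    # weight for inference, subtract it back out so lora_A/lora_B train against
    # the frozen weight.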
    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if self.merge_weights and self.merged:
            if self.r > 0:
                mask = ~(self.trainable_ids == -1)
                trainable_ids = self.trainable_ids[mask]
                self.weight.data[mask] -= (self.lora_B @ self.lora_A).T[trainable_ids] * self.scaling
            self.merged = False

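    # Switching to inference mode: fold the scaled low-rank delta into the base
    # weight so the embedding lookup needs no extra computation.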
    def eval(self):
        nn.Embedding.eval(self)
        if self.merge_weights and not self.merged:
            if self.r > 0:
                mask = ~(self.trainable_ids == -1)
                trainable_ids = self.trainable_ids[mask]
                self.weight.data[mask] += (self.lora_B @ self.lora_A).T[trainable_ids] * self.scaling
            self.merged = True

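    # Look up the frozen embeddings, then add the scaled low-rank update for
    # tokens that have been marked trainable, unless the delta has already been
    # merged into the weight.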
    def forward(self, input_ids: torch.Tensor):
        result = nn.Embedding.forward(self, input_ids)

        if self.r > 0 and not self.merged:
            trainable_ids = self.trainable_ids[input_ids]
            mask = ~(trainable_ids == -1)
            trainable_ids = trainable_ids[mask]

            after_A = F.embedding(
                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result[mask] += (after_A @ self.lora_B.T) * self.scaling

        return result
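
# Usage sketch (illustrative only, not part of the class above): wrap a model's
# token embedding, mark newly added token ids as trainable, then optimize just
# the LoRA factors. `new_token_ids` is a placeholder for a 1-D LongTensor of
# the added ids.
#
#   emb = LoraEmbedding(num_embeddings, embedding_dim, r=8, lora_alpha=8)
#   emb.mark_trainable(new_token_ids)
#   optimizer = torch.optim.AdamW([emb.lora_A, emb.lora_B], lr=1e-4)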