path: root/models/lora.py
Diffstat (limited to 'models/lora.py')
-rw-r--r--  models/lora.py  77
1 file changed, 38 insertions(+), 39 deletions(-)
diff --git a/models/lora.py b/models/lora.py
index 01a540b..e506cff 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -1,8 +1,8 @@
 from typing import Optional
+import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class LoraLayer():
@@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
-        self.trainable_ids -= 1
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
 
-        if r > 0:
-            self.lora_A = nn.ParameterList()
-            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-            self.scaling = self.lora_alpha / self.r
-            self.weight.requires_grad = False
+        self.lora_A = nn.ParameterList()
+        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
+        self.scaling = self.lora_alpha / self.r
+        self.weight.requires_grad = False
 
         self.reset_parameters()
 
@@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         else:
             nn.init.zeros_(new_emb.weight.data)
         new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        new_emb.lora_A = self.lora_A
-        new_emb.lora_B = self.lora_B
+        for param in self.lora_A:
+            new_emb.lora_A.append(param)
+        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
         new_emb.trainable_ids[:n] = self.trainable_ids[:n]
 
         return new_emb
@@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
+            w = self.weight.new_zeros(self.r)
+            self.lora_A.append(w)
+
+        if len(self.lora_A) > 1:
+            elems = torch.stack([param for param in self.lora_A])
+            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
 
     def get_weights(self, input_ids: torch.Tensor):
         if len(input_ids.shape) != 1:
             return torch.stack([self.get_weights(batch) for batch in input_ids])
 
-        trainable_ids = self.trainable_ids[input_ids]
-        mask = ~(trainable_ids == -1)
-        trainable_ids = trainable_ids[mask]
-
         weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-        elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) != 0:
-            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-            weights[mask] = w.to(dtype=weights.dtype)
+        if not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            elems = [self.lora_A[id] for id in trainable_ids[mask]]
+
+            if len(elems) != 0:
+                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+                weights[mask] = w.to(dtype=weights.dtype)
 
         return weights
 
     def persist(self):
-        if self.r > 0:
-            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.weight.data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.lora_A = nn.ParameterList()
+        nn.init.zeros_(self.lora_B.weight)
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
             self.trainable_ids[:] = -1
             self.lora_A = nn.ParameterList()
             nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data -= weights
+        self.lora_A.train(mode)
+        self.lora_B.train(mode)
+        if not mode and self.merge_weights and not self.merged:
+            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.merged = True
+        elif self.merge_weights and self.merged:
+            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data += weights
-                self.merged = True
+        self.lora_A.eval()
+        self.lora_B.eval()
 
     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
-
-        if self.r > 0 and not self.merged:
-            weights = self.get_weights(input_ids)
-            result += weights
-
+        result += self.get_weights(input_ids)
         return result
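
For context, a minimal usage sketch of the merge/unmerge behaviour introduced in this commit. It assumes the LoraEmbedding constructor accepts the parameter names visible in the diff (num_embeddings, embedding_dim, r, lora_alpha, merge_weights) as keyword arguments; the actual signature and defaults may differ.

import torch
from models.lora import LoraEmbedding

# Assumed constructor call; only the parameter names are taken from this diff.
emb = LoraEmbedding(num_embeddings=1000, embedding_dim=64,
                    r=8, lora_alpha=16, merge_weights=True)

ids = torch.tensor([[1, 2, 3]])
out = emb(ids)  # forward() now always adds get_weights(input_ids); while no
                # trainable ids are registered (all -1), the delta is all zeros,
                # and once merged, get_weights() also returns zeros.

# With merge_weights=True, the reworked train()/eval() pair folds the LoRA delta
# into weight.data when switching to eval (merged=True) and subtracts it again
# when switching back to training mode (merged=False).
emb.eval()
emb.train()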