Diffstat (limited to 'models')
-rw-r--r--  models/lora.py  145
1 file changed, 0 insertions(+), 145 deletions(-)
diff --git a/models/lora.py b/models/lora.py
deleted file mode 100644
index e506cff..0000000
--- a/models/lora.py
+++ /dev/null
@@ -1,145 +0,0 @@
from typing import Optional
import math

import torch
import torch.nn as nn


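# Mixin holding the shared LoRA hyperparameters (rank r, scaling alpha, dropout
# probability) plus the bookkeeping flags used when merging the low-rank update
# into the frozen base weight.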
class LoraLayer:
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout_p = lora_dropout

        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        self.merged = False
        self.merge_weights = merge_weights


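# nn.Embedding variant that learns per-token LoRA vectors: each token id that
# is marked trainable gets its own rank-r vector in `lora_A`, which the shared
# `lora_B` projects up to `embedding_dim`. The constructor divides by r, so in
# practice r must be > 0 despite the r=0 default.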
class LoraEmbedding(nn.Embedding, LoraLayer):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
        )

        # trainable_ids maps a token id to its slot in lora_A, or -1 if the
        # token has no LoRA vector yet.
        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)

        self.lora_A = nn.ParameterList()
        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
        self.scaling = self.lora_alpha / self.r
        self.weight.requires_grad = False

        self.reset_parameters()

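    # Returns a copy of this embedding resized to `new_num_embeddings`,
    # carrying over the frozen weights, the LoRA vectors, and the
    # trainable-id mapping for the overlapping rows.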
    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = LoraEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.r,
            self.lora_alpha,
            self.lora_dropout_p,
            self.merge_weights,
            device=self.weight.device,
            dtype=self.weight.dtype
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        # Share the existing LoRA vectors with the resized module.
        for param in self.lora_A:
            new_emb.lora_A.append(param)
        # Copy lora_B in place; assigning to `.data` of a temporary slice view
        # would silently leave the new module's weight untouched.
        new_emb.lora_B.weight.data.copy_(self.lora_B.weight.data)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        return new_emb

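    # Registers a LoRA vector for every token id in `input_ids` that does not
    # have one yet; already-registered ids are left untouched.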
    def mark_trainable(self, input_ids: torch.LongTensor):
        trainable_ids = self.trainable_ids[input_ids]
        new_ids = input_ids[trainable_ids == -1]

        if new_ids.shape[0] == 0:
            return

        n1 = len(self.lora_A)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2, device=self.trainable_ids.device)

        for _ in new_ids:
            # Initialise each new vector the way LoRA initialises its A matrix;
            # kaiming_uniform_ needs a 2-D tensor, hence the temporary view.
            # Wrapping in nn.Parameter registers it with the ParameterList.
            w = self.weight.new_zeros(self.r)
            nn.init.kaiming_uniform_(w.view(1, -1), a=math.sqrt(5))
            self.lora_A.append(nn.Parameter(w))

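    # Builds the LoRA contribution for the given token ids: zero rows for ids
    # without a LoRA vector, scaling * lora_B(dropout(lora_A[i])) for the rest.
    # Returns all zeros once the update has been merged into `weight`.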
    def get_weights(self, input_ids: torch.Tensor):
        if len(input_ids.shape) != 1:
            return torch.stack([self.get_weights(batch) for batch in input_ids])

        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))

        if not self.merged:
            trainable_ids = self.trainable_ids[input_ids]
            mask = trainable_ids != -1
            elems = [self.lora_A[idx] for idx in trainable_ids[mask]]

            if len(elems) != 0:
                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
                weights[mask] = w.to(dtype=weights.dtype)

        return weights

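    # Folds the current LoRA contribution permanently into the frozen
    # embedding weight and clears all LoRA state.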
    def persist(self):
        all_ids = torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
        self.weight.data += self.get_weights(all_ids)
        self.trainable_ids[:] = -1
        self.lora_A = nn.ParameterList()
        nn.init.zeros_(self.lora_B.weight)

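    # Re-initialises the embedding and drops all LoRA state; lora_B starts at
    # zero, so a freshly marked token initially contributes nothing.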
    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, "lora_A"):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            nn.init.zeros_(self.lora_B.weight)

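    # Follows the usual LoRA convention: merge the low-rank update into the
    # frozen weight when switching to eval, undo the merge when switching back
    # to training.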
    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        self.lora_A.train(mode)
        self.lora_B.train(mode)
        all_ids = torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
        if not mode and self.merge_weights and not self.merged:
            # Switching to eval: fold the LoRA update into the frozen weight
            # (dropout is already inactive here, so the merge is exact).
            self.weight.data += self.get_weights(all_ids)
            self.merged = True
        elif mode and self.merge_weights and self.merged:
            # Switching back to training: undo the merge. The flag has to be
            # cleared first (get_weights() returns zeros while merged) and
            # dropout kept inactive so exactly the merged values are removed.
            self.merged = False
            self.lora_dropout.eval()
            self.weight.data -= self.get_weights(all_ids)
            self.lora_dropout.train(mode)
        return self

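    # nn.Embedding.eval resolves to nn.Module.eval, which calls train(False)
    # and therefore triggers the merge above; the explicit lora_A / lora_B
    # calls are redundant but harmless.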
    def eval(self):
        nn.Embedding.eval(self)
        self.lora_A.eval()
        self.lora_B.eval()
        return self

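    # Standard embedding lookup plus the per-token LoRA contribution (which is
    # all zeros while the update is merged into `weight`).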
    def forward(self, input_ids: torch.LongTensor):
        result = nn.Embedding.forward(self, input_ids)
        result += self.get_weights(input_ids)
        return result
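

# Usage sketch (not part of the original file; assumes r > 0, toy sizes, and
# purely illustrative token ids):
#
#     emb = LoraEmbedding(num_embeddings=10, embedding_dim=8, r=4, lora_alpha=8)
#     emb.mark_trainable(torch.tensor([3, 5]))   # ids 3 and 5 get LoRA vectors
#     out = emb(torch.tensor([[1, 3, 5]]))       # base lookup + LoRA contribution
#     emb.persist()                              # fold the learned delta into emb.weight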