diff options
Diffstat (limited to 'models/lora.py')
| -rw-r--r-- | models/lora.py | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/models/lora.py b/models/lora.py new file mode 100644 index 0000000..c0f74a6 --- /dev/null +++ b/models/lora.py | |||
| @@ -0,0 +1,131 @@ | |||
| 1 | from typing import Optional | ||
| 2 | |||
| 3 | import torch | ||
| 4 | import torch.nn as nn | ||
| 5 | import torch.nn.functional as F | ||
| 6 | |||
| 7 | |||
| 8 | class LoraLayer(): | ||
| 9 | def __init__( | ||
| 10 | self, | ||
| 11 | r: int, | ||
| 12 | lora_alpha: int, | ||
| 13 | lora_dropout: float, | ||
| 14 | merge_weights: bool, | ||
| 15 | ): | ||
| 16 | self.r = r | ||
| 17 | self.lora_alpha = lora_alpha | ||
| 18 | self.lora_dropout_p = lora_dropout | ||
| 19 | |||
| 20 | if lora_dropout > 0.: | ||
| 21 | self.lora_dropout = nn.Dropout(p=lora_dropout) | ||
| 22 | else: | ||
| 23 | self.lora_dropout = nn.Identity() | ||
| 24 | |||
| 25 | self.merged = False | ||
| 26 | self.merge_weights = merge_weights | ||
| 27 | |||
| 28 | |||
| 29 | class LoraEmbedding(nn.Embedding, LoraLayer): | ||
| 30 | def __init__( | ||
| 31 | self, | ||
| 32 | num_embeddings: int, | ||
| 33 | embedding_dim: int, | ||
| 34 | r: int = 0, | ||
| 35 | lora_alpha: int = 1, | ||
| 36 | lora_dropout: float = 0.0, | ||
| 37 | merge_weights: bool = True, | ||
| 38 | **kwargs | ||
| 39 | ): | ||
| 40 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | ||
| 41 | LoraLayer.__init__( | ||
| 42 | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights | ||
| 43 | ) | ||
| 44 | |||
| 45 | self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) | ||
| 46 | self.trainable_ids -= 1 | ||
| 47 | |||
| 48 | if r > 0: | ||
| 49 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0))) | ||
| 50 | self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) | ||
| 51 | self.scaling = self.lora_alpha / self.r | ||
| 52 | self.weight.requires_grad = False | ||
| 53 | |||
| 54 | self.reset_parameters() | ||
| 55 | |||
| 56 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | ||
| 57 | n = min(self.num_embeddings, new_num_embeddings) | ||
| 58 | |||
| 59 | new_emb = LoraEmbedding( | ||
| 60 | new_num_embeddings, | ||
| 61 | self.embedding_dim, | ||
| 62 | self.r, | ||
| 63 | self.lora_alpha, | ||
| 64 | self.lora_dropout_p, | ||
| 65 | device=self.weight.device, | ||
| 66 | dtype=self.weight.dtype | ||
| 67 | ) | ||
| 68 | if initializer_factor is not None: | ||
| 69 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
| 70 | else: | ||
| 71 | nn.init.zeros_(new_emb.weight.data) | ||
| 72 | new_emb.weight.data[:n, :] = self.weight.data[:n, :] | ||
| 73 | new_emb.lora_A = self.lora_A | ||
| 74 | new_emb.lora_B = self.lora_B | ||
| 75 | new_emb.trainable_ids[:n] = self.trainable_ids[:n] | ||
| 76 | |||
| 77 | return new_emb | ||
| 78 | |||
| 79 | def mark_trainable(self, input_ids): | ||
| 80 | trainable_ids = self.trainable_ids[input_ids] | ||
| 81 | new_ids = trainable_ids[trainable_ids == -1] | ||
| 82 | |||
| 83 | if new_ids.shape[0] == 0: | ||
| 84 | return | ||
| 85 | |||
| 86 | n = self.trainable_ids.shape[0] | ||
| 87 | self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) | ||
| 88 | |||
| 89 | lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) | ||
| 90 | lora_A.data[:n] = self.lora_A.data | ||
| 91 | self.lora_A = lora_A | ||
| 92 | |||
| 93 | def reset_parameters(self): | ||
| 94 | nn.Embedding.reset_parameters(self) | ||
| 95 | if hasattr(self, 'lora_A'): | ||
| 96 | nn.init.zeros_(self.lora_A) | ||
| 97 | nn.init.normal_(self.lora_B) | ||
| 98 | |||
| 99 | def train(self, mode: bool = True): | ||
| 100 | nn.Embedding.train(self, mode) | ||
| 101 | if self.merge_weights and self.merged: | ||
| 102 | if self.r > 0: | ||
| 103 | mask = ~(self.trainable_ids == -1) | ||
| 104 | trainable_ids = self.trainable_ids[mask] | ||
| 105 | self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling | ||
| 106 | self.merged = False | ||
| 107 | |||
| 108 | def eval(self): | ||
| 109 | nn.Embedding.eval(self) | ||
| 110 | if self.merge_weights and not self.merged: | ||
| 111 | if self.r > 0: | ||
| 112 | mask = ~(self.trainable_ids == -1) | ||
| 113 | trainable_ids = self.trainable_ids[mask] | ||
| 114 | self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling | ||
| 115 | self.merged = True | ||
| 116 | |||
| 117 | def forward(self, input_ids: torch.Tensor): | ||
| 118 | result = nn.Embedding.forward(self, input_ids) | ||
| 119 | |||
| 120 | if self.r > 0 and not self.merged: | ||
| 121 | trainable_ids = self.trainable_ids[input_ids] | ||
| 122 | mask = ~(trainable_ids == -1) | ||
| 123 | trainable_ids = trainable_ids[mask] | ||
| 124 | |||
| 125 | after_A = F.embedding( | ||
| 126 | trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm, | ||
| 127 | self.norm_type, self.scale_grad_by_freq, self.sparse | ||
| 128 | ) | ||
| 129 | result[mask] += (after_A @ self.lora_B.T) * self.scaling | ||
| 130 | |||
| 131 | return result | ||
