From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 models/lora.py | 136 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 136 insertions(+)
 create mode 100644 models/lora.py

diff --git a/models/lora.py b/models/lora.py
new file mode 100644
index 0000000..c0f74a6
--- /dev/null
+++ b/models/lora.py
@@ -0,0 +1,136 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LoraLayer:
+    def __init__(
+        self,
+        r: int,
+        lora_alpha: int,
+        lora_dropout: float,
+        merge_weights: bool,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout_p = lora_dropout
+
+        if lora_dropout > 0.:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = nn.Identity()
+
+        self.merged = False
+        self.merge_weights = merge_weights
+
+
+class LoraEmbedding(nn.Embedding, LoraLayer):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+        LoraLayer.__init__(
+            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
+        )
+
+        # trainable_ids maps a token id to its column in lora_A; -1 means "not trainable"
+        self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
+        self.trainable_ids -= 1
+
+        if r > 0:
+            # lora_A grows one column per trainable token (see mark_trainable)
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
+            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.scaling = self.lora_alpha / self.r
+            self.weight.requires_grad = False
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = LoraEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.r,
+            self.lora_alpha,
+            self.lora_dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        new_emb.lora_A = self.lora_A
+        new_emb.lora_B = self.lora_B
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids: torch.Tensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1].unique()
+
+        if new_ids.shape[0] == 0:
+            return
+
+        # Assign the next free columns of lora_A to the newly trainable token ids
+        n = self.lora_A.shape[1]
+        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0], device=self.trainable_ids.device)
+
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n + new_ids.shape[0])))
+        lora_A.data[:, :n] = self.lora_A.data
+        self.lora_A = lora_A
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            nn.init.zeros_(self.lora_A)
+            nn.init.normal_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        if self.merge_weights and self.merged:
+            if self.r > 0:
+                # Undo the merge so the LoRA delta is applied on the fly again
+                mask = ~(self.trainable_ids == -1)
+                trainable_ids = self.trainable_ids[mask]
+                self.weight.data[mask] -= (self.lora_B @ self.lora_A).T[trainable_ids] * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        if self.merge_weights and not self.merged:
+            if self.r > 0:
+                # Bake the LoRA delta into the frozen embedding weights for inference
+                mask = ~(self.trainable_ids == -1)
+                trainable_ids = self.trainable_ids[mask]
+                self.weight.data[mask] += (self.lora_B @ self.lora_A).T[trainable_ids] * self.scaling
+            self.merged = True
+
+    def forward(self, input_ids: torch.Tensor):
+        result = nn.Embedding.forward(self, input_ids)
+
+        if self.r > 0 and not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            trainable_ids = trainable_ids[mask]
+
+            after_A = F.embedding(
+                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
+                self.norm_type, self.scale_grad_by_freq, self.sparse
+            )
+            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+
+        return result
-- 
cgit v1.2.3-54-g00ecf
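For context, a minimal usage sketch of how LoraEmbedding could back textual inversion on a CLIP text encoder. This is an illustration under assumptions, not code from this commit: the checkpoint name, the "<concept>" placeholder token, and the rank/learning-rate values are hypothetical.

    # Sketch only: swap a CLIP text encoder's token embedding for a LoraEmbedding
    # and train LoRA columns for a new placeholder token (textual inversion).
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    from models.lora import LoraEmbedding

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # Register a hypothetical placeholder token for the new concept
    tokenizer.add_tokens(["<concept>"])
    placeholder_ids = tokenizer.convert_tokens_to_ids(["<concept>"])

    # Wrap the existing token embedding in a LoRA-backed embedding,
    # then resize it to cover the enlarged vocabulary
    old_emb = text_encoder.get_input_embeddings()
    lora_emb = LoraEmbedding(old_emb.num_embeddings, old_emb.embedding_dim, r=8, lora_alpha=8)
    lora_emb.weight.data.copy_(old_emb.weight.data)
    lora_emb = lora_emb.new_resized(len(tokenizer), initializer_factor=1.0)
    text_encoder.set_input_embeddings(lora_emb)

    # Only the placeholder token gets trainable LoRA columns; the base weight stays frozen
    lora_emb.mark_trainable(torch.tensor(placeholder_ids, dtype=torch.long))
    optimizer = torch.optim.AdamW([lora_emb.lora_A, lora_emb.lora_B], lr=1e-4)

Because mark_trainable replaces lora_A with a larger Parameter, the optimizer should only be created after every placeholder token has been registered.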