summaryrefslogtreecommitdiffstats
path: root/models/lora.py
blob: e506cfff09cd9172af1ec9b93571cf621ab337d2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Optional
import math

import torch
import torch.nn as nn


class LoraLayer:
    """Mixin that stores the LoRA hyper-parameters shared by LoRA modules.

    The class only records configuration and builds the dropout module; it
    holds no trainable weights itself.

    Attributes set by ``__init__``:
        r: Value passed as ``r`` (the LoRA rank).
        lora_alpha: Value passed as ``lora_alpha``.
        lora_dropout_p: The raw dropout probability passed as ``lora_dropout``.
        lora_dropout: ``nn.Dropout(p)`` when ``p > 0``, otherwise
            ``nn.Identity()``.
        merged: Always initialised to ``False``.
        merge_weights: Value passed as ``merge_weights``.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha

        # Keep the probability next to the module so the configuration can be
        # read back later without inspecting the dropout module.
        self.lora_dropout_p = lora_dropout
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0. else nn.Identity()
        )

        self.merge_weights = merge_weights
        self.merged = False


class LoraEmbedding(nn.Embedding, LoraLayer):
    """Embedding whose base table is frozen and extended by per-token LoRA rows.

    The base ``weight`` has ``requires_grad`` set to ``False``. Token ids
    registered through :meth:`mark_trainable` each receive one rank-``r``
    vector in ``lora_A``; the shared projection ``lora_B`` maps such a vector
    to an ``embedding_dim`` offset, which is multiplied by ``scaling`` and
    added to the base embedding of that token.

    ``trainable_ids`` is a long buffer with one entry per embedding row. It
    holds the index of the token's vector inside ``lora_A``, or ``-1`` when
    the token has no LoRA vector.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        """Build the embedding and an empty LoRA adapter.

        Args:
            num_embeddings: Number of rows of the embedding table.
            embedding_dim: Size of each embedding vector.
            r: LoRA rank. With ``r == 0`` the scaling factor is ``0.0``.
            lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
            lora_dropout: Dropout probability applied to the ``lora_A`` rows.
            merge_weights: If true, switching to eval mode folds the LoRA
                update into ``weight`` and switching back to train mode
                removes it again.
            **kwargs: Forwarded unchanged to ``nn.Embedding.__init__``.
        """
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
        )

        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)

        self.lora_A = nn.ParameterList()
        # Create lora_B with the same device/dtype as the base weight, because
        # the lora_A rows are allocated from the base weight as well.
        self.lora_B = nn.Linear(
            r, embedding_dim, bias=False, device=self.weight.device, dtype=self.weight.dtype
        )
        # Guard the default r == 0, which would otherwise divide by zero.
        self.scaling = self.lora_alpha / self.r if self.r > 0 else 0.0
        self.weight.requires_grad = False

        self.reset_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        """Return a copy of this layer with ``new_num_embeddings`` rows.

        The first ``min(old, new)`` rows of ``weight`` and ``trainable_ids``
        are copied. Additional rows are zero, or drawn from
        ``N(0, (initializer_factor * 0.02) ** 2)`` when ``initializer_factor``
        is given. The ``lora_A`` parameters are shared with this layer (the
        same ``Parameter`` objects); ``lora_B`` is copied by value.

        Args:
            new_num_embeddings: Row count of the returned layer.
            initializer_factor: Optional multiplier for the standard deviation
                used to initialise the table before the old rows are copied.

        Returns:
            A new ``LoraEmbedding`` on the same device and dtype as ``weight``.
        """
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = LoraEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.r,
            self.lora_alpha,
            self.lora_dropout_p,
            merge_weights=self.merge_weights,
            device=self.weight.device,
            dtype=self.weight.dtype
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        for param in self.lora_A:
            new_emb.lora_A.append(param)
        # Copy into the parameter itself; assigning `.data` on a `weight[:]`
        # slice only rebinds a temporary view and leaves the parameter as is.
        with torch.no_grad():
            new_emb.lora_B.weight.copy_(self.lora_B.weight)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        # The copied rows may already contain the merged LoRA update, so the
        # copy has to carry the same merge state. The mode flags are set via
        # nn.Module.train to avoid triggering the merge logic in `train`.
        new_emb.merged = self.merged
        nn.Module.train(new_emb, self.training)

        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Allocate a LoRA vector for every id in ``input_ids`` lacking one.

        Ids that already have a vector are left untouched, and repeated ids
        receive a single vector. New vectors are drawn with
        ``kaiming_uniform_(a=sqrt(5))`` over an ``(n_new, r)`` matrix.

        Args:
            input_ids: Tensor of token ids; any shape.
        """
        slots = self.trainable_ids[input_ids]
        # Deduplicate so one token never gets more than one lora_A row.
        new_ids = torch.unique(input_ids[slots == -1])

        if new_ids.numel() == 0:
            return

        n1 = len(self.lora_A)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2, device=self.trainable_ids.device)

        # Initialise only the new rows, and do it on the tensor that actually
        # backs the parameters: initialising a torch.stack() of the existing
        # parameters would modify a copy and leave them at zero.
        init = self.weight.new_zeros((new_ids.shape[0], self.r))
        if self.r > 0:
            nn.init.kaiming_uniform_(init, a=math.sqrt(5))
        for row in init:
            self.lora_A.append(nn.Parameter(row.clone()))

    def _lora_delta(self, input_ids: torch.Tensor, use_dropout: bool = True):
        """Compute the scaled LoRA offset for ``input_ids``, ignoring ``merged``.

        Returns a tensor of shape ``input_ids.shape + (embedding_dim,)`` that
        is zero for ids without a LoRA vector.
        """
        flat_ids = input_ids.reshape(-1)
        delta = self.weight.new_zeros((flat_ids.shape[0], self.embedding_dim))

        slots = self.trainable_ids[flat_ids]
        mask = slots != -1
        rows = [self.lora_A[slot] for slot in slots[mask].tolist()]

        if rows:
            stacked = torch.stack(rows)
            if use_dropout:
                stacked = self.lora_dropout(stacked)
            update = self.lora_B(stacked) * self.scaling
            delta[mask] = update.to(dtype=delta.dtype)

        return delta.reshape(*input_ids.shape, self.embedding_dim)

    def _full_delta(self):
        """Return the dropout-free LoRA offset for every row of the table."""
        all_ids = torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
        with torch.no_grad():
            return self._lora_delta(all_ids, use_dropout=False)

    def get_weights(self, input_ids: torch.Tensor):
        """Return the LoRA offset to add to the base embedding of ``input_ids``.

        While the update is merged into ``weight`` the offset is all zeros, so
        that it is not applied twice.

        Args:
            input_ids: Tensor of token ids; any shape.

        Returns:
            Tensor of shape ``input_ids.shape + (embedding_dim,)``.
        """
        if self.merged:
            return self.weight.new_zeros((*input_ids.shape, self.embedding_dim))

        return self._lora_delta(input_ids)

    def persist(self):
        """Fold the LoRA update into ``weight`` permanently and clear the adapter.

        Afterwards no token is marked trainable, ``lora_A`` is empty and
        ``lora_B`` is zero.
        """
        # If the update is already merged, `weight` contains it; only fold it
        # in when it is not, and never with dropout applied.
        if not self.merged:
            self.weight.data += self._full_delta()
        self.trainable_ids[:] = -1
        self.lora_A = nn.ParameterList()
        nn.init.zeros_(self.lora_B.weight)

    def reset_parameters(self):
        """Re-initialise the base table and, once it exists, empty the adapter."""
        nn.Embedding.reset_parameters(self)
        # nn.Embedding.__init__ calls this before the LoRA attributes exist.
        if hasattr(self, "lora_A"):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            nn.init.zeros_(self.lora_B.weight)
            # A freshly initialised table has nothing merged into it.
            self.merged = False

    def train(self, mode: bool = True):
        """Set train/eval mode and merge or un-merge the LoRA update.

        With ``merge_weights`` enabled, entering eval mode adds the LoRA
        offset to ``weight`` once, and entering train mode subtracts it again.
        Repeated calls with the same ``mode`` do not change ``weight``.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        nn.Embedding.train(self, mode)
        self.lora_A.train(mode)
        self.lora_B.train(mode)

        if self.merge_weights:
            if not mode and not self.merged:
                self.weight.data += self._full_delta()
                self.merged = True
            elif mode and self.merged:
                # `_full_delta` ignores the merged flag and dropout, so this
                # removes exactly what the merge step added.
                self.weight.data -= self._full_delta()
                self.merged = False

        return self

    def eval(self):
        """Switch to evaluation mode; equivalent to ``train(False)``.

        Returns:
            ``self``, matching ``nn.Module.eval``.
        """
        return self.train(False)

    def forward(self, input_ids: torch.LongTensor):
        """Look up the base embeddings and add the LoRA offset of each id."""
        result = nn.Embedding.forward(self, input_ids)
        return result + self.get_weights(input_ids)