| author | Volpeon <git@volpeon.ink> | 2023-05-12 18:06:13 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-12 18:06:13 +0200 |
| commit | daba86ebbfa821f3c3227bcfbcbd532051e793e7 (patch) | |
| tree | 37b00a8216b6298004b75bd06ffc5c8ff76bce8e /models | |
| parent | Update (diff) | |
| download | textual-inversion-diff-daba86ebbfa821f3c3227bcfbcbd532051e793e7.tar.gz, textual-inversion-diff-daba86ebbfa821f3c3227bcfbcbd532051e793e7.tar.bz2, textual-inversion-diff-daba86ebbfa821f3c3227bcfbcbd532051e793e7.zip | |
Fix for latest PEFT
Diffstat (limited to 'models')
-rw-r--r-- | models/lora.py | 145 |
1 file changed, 0 insertions, 145 deletions
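For orientation before the diff: the removed models/lora.py implemented LoRA directly on an nn.Embedding, where only token ids explicitly passed to mark_trainable() receive a low-rank update (lora_B(lora_A[id]) scaled by lora_alpha / r, added on top of the frozen base embedding). Below is a minimal usage sketch of that now-deleted API; the vocabulary size, embedding width, and LoRA hyperparameters are illustrative assumptions, not taken from this repository's training code.

```python
# Usage sketch of the removed models.lora.LoraEmbedding (as it existed before this
# commit). Sizes and hyperparameters are illustrative assumptions.
import torch
from models.lora import LoraEmbedding

emb = LoraEmbedding(num_embeddings=49410, embedding_dim=768, r=8, lora_alpha=8)

# Only ids marked here get a lora_A vector; all other rows keep the frozen weight.
new_token_ids = torch.tensor([49408, 49409])
emb.mark_trainable(new_token_ids)

# forward() returns the base embedding plus the scaled LoRA delta for marked ids.
input_ids = torch.tensor([[49408, 1, 2, 49409]])
out = emb(input_ids)  # shape: (1, 4, 768)

# persist() folds the learned deltas into emb.weight and clears the LoRA state.
emb.persist()
```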
```diff
diff --git a/models/lora.py b/models/lora.py
deleted file mode 100644
index e506cff..0000000
--- a/models/lora.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from typing import Optional
-import math
-
-import torch
-import torch.nn as nn
-
-
-class LoraLayer():
-    def __init__(
-        self,
-        r: int,
-        lora_alpha: int,
-        lora_dropout: float,
-        merge_weights: bool,
-    ):
-        self.r = r
-        self.lora_alpha = lora_alpha
-        self.lora_dropout_p = lora_dropout
-
-        if lora_dropout > 0.:
-            self.lora_dropout = nn.Dropout(p=lora_dropout)
-        else:
-            self.lora_dropout = nn.Identity()
-
-        self.merged = False
-        self.merge_weights = merge_weights
-
-
-class LoraEmbedding(nn.Embedding, LoraLayer):
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_dim: int,
-        r: int = 0,
-        lora_alpha: int = 1,
-        lora_dropout: float = 0.0,
-        merge_weights: bool = True,
-        **kwargs
-    ):
-        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
-        LoraLayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
-
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
-
-        self.lora_A = nn.ParameterList()
-        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-        self.scaling = self.lora_alpha / self.r
-        self.weight.requires_grad = False
-
-        self.reset_parameters()
-
-    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
-        n = min(self.num_embeddings, new_num_embeddings)
-
-        new_emb = LoraEmbedding(
-            new_num_embeddings,
-            self.embedding_dim,
-            self.r,
-            self.lora_alpha,
-            self.lora_dropout_p,
-            device=self.weight.device,
-            dtype=self.weight.dtype
-        )
-        if initializer_factor is not None:
-            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-        else:
-            nn.init.zeros_(new_emb.weight.data)
-        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        for param in self.lora_A:
-            new_emb.lora_A.append(param)
-        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
-        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
-
-        return new_emb
-
-    def mark_trainable(self, input_ids: torch.LongTensor):
-        trainable_ids = self.trainable_ids[input_ids]
-        new_ids = input_ids[trainable_ids == -1]
-
-        if new_ids.shape[0] == 0:
-            return
-
-        n1 = len(self.lora_A)
-        n2 = n1 + new_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n1, n2)
-        for _ in new_ids:
-            w = self.weight.new_zeros(self.r)
-            self.lora_A.append(w)
-
-        if len(self.lora_A) > 1:
-            elems = torch.stack([param for param in self.lora_A])
-            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
-
-    def get_weights(self, input_ids: torch.Tensor):
-        if len(input_ids.shape) != 1:
-            return torch.stack([self.get_weights(batch) for batch in input_ids])
-
-        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-
-        if not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            elems = [self.lora_A[id] for id in trainable_ids[mask]]
-
-            if len(elems) != 0:
-                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-                weights[mask] = w.to(dtype=weights.dtype)
-
-        return weights
-
-    def persist(self):
-        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.lora_A = nn.ParameterList()
-        nn.init.zeros_(self.lora_B.weight)
-
-    def reset_parameters(self):
-        nn.Embedding.reset_parameters(self)
-        if hasattr(self, "lora_A"):
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
-            nn.init.zeros_(self.lora_B.weight)
-
-    def train(self, mode: bool = True):
-        nn.Embedding.train(self, mode)
-        self.lora_A.train(mode)
-        self.lora_B.train(mode)
-        if not mode and self.merge_weights and not self.merged:
-            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.merged = True
-        elif self.merge_weights and self.merged:
-            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.merged = False
-
-    def eval(self):
-        nn.Embedding.eval(self)
-        self.lora_A.eval()
-        self.lora_B.eval()
-
-    def forward(self, input_ids: torch.LongTensor):
-        result = nn.Embedding.forward(self, input_ids)
-        result += self.get_weights(input_ids)
-        return result
```
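The commit message only says "Fix for latest PEFT", so the custom layer above is presumably superseded by the LoRA support that ships with the peft library. A hedged sketch of wrapping a text encoder's token embedding with peft instead; the checkpoint name, target-module name, and hyperparameters are assumptions for illustration, not taken from this repository:

```python
# Hypothetical replacement sketch using the peft library's built-in LoRA instead of
# the custom LoraEmbedding removed above. Checkpoint name, target module, and
# hyperparameters are illustrative assumptions.
from transformers import CLIPTextModel
from peft import LoraConfig, get_peft_model

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.0,
    # peft's LoRA can wrap nn.Embedding modules such as CLIP's token_embedding.
    target_modules=["token_embedding"],
)
text_encoder = get_peft_model(text_encoder, lora_config)
text_encoder.print_trainable_parameters()
```

One behavioral difference worth noting: peft's LoRA on an nn.Embedding adapts the whole embedding matrix, whereas the deleted LoraEmbedding only updated the token ids passed to mark_trainable().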