author:    Volpeon <git@volpeon.ink>  2023-04-16 14:45:37 +0200
committer: Volpeon <git@volpeon.ink>  2023-04-16 14:45:37 +0200
commit:    3924055ed24da9b6995303cd36282eb558ba0bf0 (patch)
tree:      4fed8dabcde2236e1a1e8f5738b2a0bdcfd4513b /models/lora.py
parent:    Fix (diff)

Fix
Diffstat (limited to 'models/lora.py')
-rw-r--r--  models/lora.py | 77
1 file changed, 38 insertions(+), 39 deletions(-)
```diff
diff --git a/models/lora.py b/models/lora.py
index 01a540b..e506cff 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -1,8 +1,8 @@
 from typing import Optional
+import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class LoraLayer():
@@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
-        self.trainable_ids -= 1
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
 
-        if r > 0:
-            self.lora_A = nn.ParameterList()
-            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-            self.scaling = self.lora_alpha / self.r
-            self.weight.requires_grad = False
+        self.lora_A = nn.ParameterList()
+        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
+        self.scaling = self.lora_alpha / self.r
+        self.weight.requires_grad = False
 
         self.reset_parameters()
 
@@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         else:
             nn.init.zeros_(new_emb.weight.data)
         new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        new_emb.lora_A = self.lora_A
-        new_emb.lora_B = self.lora_B
+        for param in self.lora_A:
+            new_emb.lora_A.append(param)
+        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
         new_emb.trainable_ids[:n] = self.trainable_ids[:n]
 
         return new_emb
@@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
+            w = self.weight.new_zeros(self.r)
+            self.lora_A.append(w)
+
+        if len(self.lora_A) > 1:
+            elems = torch.stack([param for param in self.lora_A])
+            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
 
     def get_weights(self, input_ids: torch.Tensor):
         if len(input_ids.shape) != 1:
             return torch.stack([self.get_weights(batch) for batch in input_ids])
 
-        trainable_ids = self.trainable_ids[input_ids]
-        mask = ~(trainable_ids == -1)
-        trainable_ids = trainable_ids[mask]
-
         weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-        elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) != 0:
-            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-            weights[mask] = w.to(dtype=weights.dtype)
+        if not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            elems = [self.lora_A[id] for id in trainable_ids[mask]]
+
+            if len(elems) != 0:
+                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+                weights[mask] = w.to(dtype=weights.dtype)
 
         return weights
 
     def persist(self):
-        if self.r > 0:
-            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.weight.data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.lora_A = nn.ParameterList()
+        nn.init.zeros_(self.lora_B.weight)
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data -= weights
+        self.lora_A.train(mode)
+        self.lora_B.train(mode)
+        if not mode and self.merge_weights and not self.merged:
+            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.merged = True
+        elif self.merge_weights and self.merged:
+            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data += weights
-            self.merged = True
+        self.lora_A.eval()
+        self.lora_B.eval()
 
     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
-
-        if self.r > 0 and not self.merged:
-            weights = self.get_weights(input_ids)
-            result += weights
-
+        result += self.get_weights(input_ids)
         return result
```
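
For context, here is a minimal, self-contained sketch (not part of the commit) of the per-token LoRA embedding pattern this file implements: a `trainable_ids` buffer maps vocabulary ids to rows of a per-token `lora_A` list, a shared `lora_B` projection turns those rows into embedding-dim deltas, and only the selected tokens receive updates while the base embedding stays frozen. The class and method names below (`TinyLoraEmbedding`, `mark_trainable`) are illustrative assumptions, not the repository's actual API.

```python
import math

import torch
import torch.nn as nn


class TinyLoraEmbedding(nn.Embedding):
    """Illustrative sketch: LoRA deltas applied only to selected token ids."""

    def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 4, lora_alpha: int = 1):
        super().__init__(num_embeddings, embedding_dim)
        self.r = r
        self.scaling = lora_alpha / r
        # -1 marks "not trainable"; any other value indexes a row of lora_A
        self.register_buffer("trainable_ids", torch.full((num_embeddings,), -1, dtype=torch.long))
        self.lora_A = nn.ParameterList()               # one r-dim vector per trainable token
        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
        nn.init.zeros_(self.lora_B.weight)             # delta starts at zero
        self.weight.requires_grad = False              # base embedding stays frozen

    def mark_trainable(self, new_ids: torch.Tensor):
        # Assign consecutive lora_A rows to the newly trainable token ids.
        n1 = len(self.lora_A)
        self.trainable_ids[new_ids] = torch.arange(n1, n1 + new_ids.shape[0])
        for _ in new_ids:
            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
        # Kaiming-init the stacked rows, then copy back into the parameters.
        if len(self.lora_A) > 0:
            elems = torch.stack([p.data for p in self.lora_A])
            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
            for p, e in zip(self.lora_A, elems):
                p.data.copy_(e)

    def forward(self, input_ids: torch.LongTensor):
        result = super().forward(input_ids)
        flat = input_ids.reshape(-1)
        rows = self.trainable_ids[flat]
        mask = rows != -1
        if mask.any():
            a = torch.stack([self.lora_A[int(i)] for i in rows[mask]])
            delta = self.lora_B(a) * self.scaling      # (num_masked, embedding_dim)
            result.view(-1, self.embedding_dim)[mask] += delta
        return result


emb = TinyLoraEmbedding(100, 16, r=4)
emb.mark_trainable(torch.tensor([3, 7]))               # only tokens 3 and 7 get LoRA rows
out = emb(torch.tensor([[1, 3, 7, 9]]))
print(out.shape)                                       # torch.Size([1, 4, 16])
```

Because `lora_B` is zero-initialized, newly marked tokens initially reproduce the frozen base embedding exactly; the LoRA delta only becomes nonzero once training updates the low-rank parameters.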